ps/rdma/allshards_ps_client.cc
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "allshards_ps_client.h" | ||
| 2 | |||
| 3 | #include <algorithm> | ||
| 4 | #include <boost/coroutine2/all.hpp> | ||
| 5 | #include <cstring> | ||
| 6 | #include <limits> | ||
| 7 | #include <memory> | ||
| 8 | #include <stdexcept> | ||
| 9 | #include <thread> | ||
| 10 | #include <vector> | ||
| 11 | |||
| 12 | #include "base/hash.h" | ||
| 13 | #include "ps/rdma/rdma_common.h" | ||
| 14 | |||
| 15 | DECLARE_int32(value_size); | ||
| 16 | DECLARE_int32(max_kv_num_per_request); | ||
| 17 | |||
| 18 | 10 | AllShardsParameterClientWrapper::AllShardsParameterClientWrapper( | |
| 19 | const std::vector<BaseParameterClient*>& clients, | ||
| 20 | int num_shards, | ||
| 21 | const std::string& hash_method, | ||
| 22 | 10 | const std::vector<int>& shard_ids) | |
| 23 | : BaseParameterClient("", 0, 0), | ||
| 24 | 10 | clients_(clients), | |
| 25 | 10 | num_shards_(num_shards), | |
| 26 |
4/8✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 14 not taken.
|
10 | hash_method_(hash_method) { |
| 27 |
2/8✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
|
10 | CHECK_EQ(static_cast<int>(clients_.size()), num_shards_); |
| 28 |
2/2✓ Branch 1 taken 2 times.
✓ Branch 2 taken 8 times.
|
10 | if (!shard_ids.empty()) { |
| 29 |
2/8✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
|
2 | CHECK_EQ(static_cast<int>(shard_ids.size()), num_shards_); |
| 30 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
|
6 | for (int i = 0; i < num_shards_; ++i) { |
| 31 |
1/2✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | shard_to_client_index_[shard_ids[static_cast<std::size_t>(i)]] = i; |
| 32 | } | ||
| 33 | } else { | ||
| 34 |
2/2✓ Branch 0 taken 16 times.
✓ Branch 1 taken 8 times.
|
24 | for (int i = 0; i < num_shards_; ++i) { |
| 35 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | shard_to_client_index_[i] = i; |
| 36 | } | ||
| 37 | } | ||
| 38 | 10 | } | |
| 39 | |||
| 40 | 38 | int AllShardsParameterClientWrapper::PartitionKey(uint64_t key) const { | |
| 41 |
1/6✗ Branch 3 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
38 | CHECK_GT(num_shards_, 0); |
| 42 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | if (hash_method_ == "city_hash") { |
| 43 | 38 | return static_cast<int>(GetHash(key) % static_cast<uint64_t>(num_shards_)); | |
| 44 | } | ||
| 45 | ✗ | if (hash_method_ == "simple_mod") { | |
| 46 | ✗ | return static_cast<int>(key % static_cast<uint64_t>(num_shards_)); | |
| 47 | } | ||
| 48 | ✗ | throw std::runtime_error("unsupported shard hash method: " + hash_method_); | |
| 49 | } | ||
| 50 | |||
| 51 | std::vector<AllShardsParameterClientWrapper::ShardChunk> | ||
| 52 | 6 | AllShardsParameterClientWrapper::BuildChunks( | |
| 53 | base::ConstArray<uint64_t> keys) const { | ||
| 54 |
1/2✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | std::vector<std::vector<uint64_t>> shard_keys(num_shards_); |
| 55 |
1/2✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | std::vector<std::vector<std::size_t>> shard_positions(num_shards_); |
| 56 | |||
| 57 |
2/2✓ Branch 1 taken 22 times.
✓ Branch 2 taken 6 times.
|
28 | for (std::size_t i = 0; i < keys.Size(); ++i) { |
| 58 |
1/2✓ Branch 2 taken 22 times.
✗ Branch 3 not taken.
|
22 | const int shard = PartitionKey(keys[i]); |
| 59 |
1/2✓ Branch 3 taken 22 times.
✗ Branch 4 not taken.
|
22 | shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]); |
| 60 |
1/2✓ Branch 2 taken 22 times.
✗ Branch 3 not taken.
|
22 | shard_positions[static_cast<std::size_t>(shard)].push_back(i); |
| 61 | } | ||
| 62 | |||
| 63 | 6 | std::vector<ShardChunk> chunks; | |
| 64 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
|
18 | for (int shard = 0; shard < num_shards_; ++shard) { |
| 65 |
1/2✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
|
12 | const int client_index = shard_to_client_index_.at(shard); |
| 66 | 12 | for (std::size_t offset = 0; | |
| 67 |
2/2✓ Branch 2 taken 16 times.
✓ Branch 3 taken 12 times.
|
28 | offset < shard_keys[static_cast<std::size_t>(shard)].size(); |
| 68 | 16 | offset += static_cast<std::size_t>(FLAGS_max_kv_num_per_request)) { | |
| 69 | 16 | const std::size_t end = std::min( | |
| 70 | 32 | offset + static_cast<std::size_t>(FLAGS_max_kv_num_per_request), | |
| 71 | 16 | shard_keys[static_cast<std::size_t>(shard)].size()); | |
| 72 | 16 | ShardChunk chunk; | |
| 73 | 16 | chunk.shard_id = shard; | |
| 74 | 16 | chunk.client_index = client_index; | |
| 75 |
1/2✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
|
32 | chunk.keys.assign( |
| 76 | 16 | shard_keys[static_cast<std::size_t>(shard)].begin() + offset, | |
| 77 | 16 | shard_keys[static_cast<std::size_t>(shard)].begin() + end); | |
| 78 |
1/2✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
|
32 | chunk.positions.assign( |
| 79 | 16 | shard_positions[static_cast<std::size_t>(shard)].begin() + offset, | |
| 80 | 16 | shard_positions[static_cast<std::size_t>(shard)].begin() + end); | |
| 81 |
1/2✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
|
16 | chunks.push_back(std::move(chunk)); |
| 82 | 16 | } | |
| 83 | } | ||
| 84 | 12 | return chunks; | |
| 85 | 6 | } | |
| 86 | |||
| 87 | 4 | bool AllShardsParameterClientWrapper::FinalizeBatchIfNeeded( | |
| 88 | BatchRequest* batch) { | ||
| 89 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
|
4 | if (batch->assembled) { |
| 90 | ✗ | return batch->status_code == | |
| 91 | ✗ | static_cast<std::int32_t>(petps::RpcStatus::kOk); | |
| 92 | } | ||
| 93 | |||
| 94 | 4 | batch->status_code = static_cast<std::int32_t>(petps::RpcStatus::kOk); | |
| 95 |
2/2✓ Branch 5 taken 12 times.
✓ Branch 6 taken 2 times.
|
14 | for (const auto& pending : batch->shard_rpcs) { |
| 96 | 12 | const auto* status_word = petps::FixedSlotStatusWord( | |
| 97 | 12 | pending.recv_buffer, pending.key_count, FLAGS_value_size); | |
| 98 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 10 times.
|
12 | if (*status_word != static_cast<std::int32_t>(petps::RpcStatus::kOk)) { |
| 99 | 2 | batch->status_code = *status_word; | |
| 100 | 2 | break; | |
| 101 | } | ||
| 102 | } | ||
| 103 | |||
| 104 | 4 | const int embedding_dim = FLAGS_value_size / sizeof(float); | |
| 105 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 2 times.
|
4 | if (batch->status_code == static_cast<std::int32_t>(petps::RpcStatus::kOk)) { |
| 106 |
2/2✓ Branch 5 taken 8 times.
✓ Branch 6 taken 2 times.
|
10 | for (const auto& pending : batch->shard_rpcs) { |
| 107 | 8 | const float* shard_values = | |
| 108 | static_cast<const float*>(pending.recv_buffer); | ||
| 109 |
2/2✓ Branch 1 taken 14 times.
✓ Branch 2 taken 8 times.
|
22 | for (std::size_t i = 0; i < pending.original_positions.size(); ++i) { |
| 110 | 28 | std::memcpy( | |
| 111 | 14 | batch->user_buffer + pending.original_positions[i] * embedding_dim, | |
| 112 | 14 | shard_values + i * embedding_dim, | |
| 113 | FLAGS_value_size); | ||
| 114 | } | ||
| 115 | } | ||
| 116 | } | ||
| 117 | |||
| 118 | 4 | auto* batch_status_word = reinterpret_cast<std::int32_t*>( | |
| 119 | 4 | reinterpret_cast<char*>(batch->user_buffer) + | |
| 120 | 4 | batch->total_key_count * static_cast<std::size_t>(FLAGS_value_size)); | |
| 121 | 4 | *batch_status_word = batch->status_code; | |
| 122 | 4 | batch->assembled = true; | |
| 123 | 4 | return batch->status_code == static_cast<std::int32_t>(petps::RpcStatus::kOk); | |
| 124 | } | ||
| 125 | |||
| 126 | 4 | void AllShardsParameterClientWrapper::WaitShardRpcsCooperatively( | |
| 127 | const std::vector<PendingShardRpc>& shard_rpcs) const { | ||
| 128 | using Coroutine = boost::coroutines2::coroutine<void>; | ||
| 129 | 4 | std::vector<std::unique_ptr<Coroutine::pull_type>> waiters; | |
| 130 |
1/2✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | waiters.reserve(shard_rpcs.size()); |
| 131 |
2/2✓ Branch 4 taken 12 times.
✓ Branch 5 taken 4 times.
|
16 | for (const auto& pending : shard_rpcs) { |
| 132 |
3/6✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 8 not taken.
|
12 | waiters.emplace_back(std::make_unique<Coroutine::pull_type>( |
| 133 | 12 | [this, pending](Coroutine::push_type& sink) { | |
| 134 | auto* client = | ||
| 135 | 12 | clients_[static_cast<std::size_t>(pending.client_index)]; | |
| 136 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
|
12 | while (!client->QueryRPCFinished(pending.rpc_id)) { |
| 137 | ✗ | sink(); | |
| 138 | } | ||
| 139 | 12 | client->WaitRPCFinish(pending.rpc_id); | |
| 140 | 12 | })); | |
| 141 | } | ||
| 142 | |||
| 143 |
2/2✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
|
8 | while (!waiters.empty()) { |
| 144 |
2/2✓ Branch 3 taken 12 times.
✓ Branch 4 taken 4 times.
|
16 | for (auto it = waiters.begin(); it != waiters.end();) { |
| 145 | 12 | auto& waiter = *it; | |
| 146 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
|
12 | if (*waiter) { |
| 147 | ✗ | (*waiter)(); | |
| 148 | } | ||
| 149 |
1/2✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
12 | if (!*waiter) { |
| 150 |
1/2✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
12 | it = waiters.erase(it); |
| 151 | } else { | ||
| 152 | ✗ | ++it; | |
| 153 | } | ||
| 154 | } | ||
| 155 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
|
4 | if (!waiters.empty()) { |
| 156 | ✗ | std::this_thread::yield(); | |
| 157 | } | ||
| 158 | } | ||
| 159 | 4 | } | |
| 160 | |||
| 161 | 2 | int AllShardsParameterClientWrapper::GetParameter( | |
| 162 | base::ConstArray<uint64_t> keys, std::vector<std::vector<float>>* values) { | ||
| 163 | 2 | values->clear(); | |
| 164 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
|
2 | if (keys.Size() == 0) { |
| 165 | ✗ | return 0; | |
| 166 | } | ||
| 167 | |||
| 168 | 2 | const int embedding_dim = FLAGS_value_size / sizeof(float); | |
| 169 |
1/2✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
|
2 | std::vector<float> flat(keys.Size() * embedding_dim + 1, 0.0f); |
| 170 |
1/2✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
2 | int rpc_id = GetParameter(keys, flat.data(), false, 0); |
| 171 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | WaitRPCFinish(rpc_id); |
| 172 | const auto* status_word = | ||
| 173 | 2 | petps::FixedSlotStatusWord(flat.data(), keys.Size(), FLAGS_value_size); | |
| 174 |
1/2✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
|
2 | if (*status_word != static_cast<std::int32_t>(petps::RpcStatus::kOk)) { |
| 175 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | RevokeRPCResource(rpc_id); |
| 176 | 2 | return -1; | |
| 177 | } | ||
| 178 | |||
| 179 | ✗ | petps::CopyFlatRowsToVectors( | |
| 180 | ✗ | flat.data(), | |
| 181 | ✗ | keys.Size(), | |
| 182 | static_cast<std::size_t>(embedding_dim), | ||
| 183 | values); | ||
| 184 | ✗ | RevokeRPCResource(rpc_id); | |
| 185 | ✗ | return 0; | |
| 186 | 2 | } | |
| 187 | |||
| 188 | 6 | int AllShardsParameterClientWrapper::GetParameter( | |
| 189 | base::ConstArray<uint64_t> keys, | ||
| 190 | float* values, | ||
| 191 | bool isAsync, | ||
| 192 | int async_req_id) { | ||
| 193 | 6 | BatchRequest batch; | |
| 194 | 6 | batch.user_buffer = values; | |
| 195 | 6 | batch.total_key_count = keys.Size(); | |
| 196 | auto* batch_status_word = | ||
| 197 | 6 | petps::FixedSlotStatusWord(values, keys.Size(), FLAGS_value_size); | |
| 198 | 6 | *batch_status_word = static_cast<std::int32_t>(petps::RpcStatus::kPending); | |
| 199 | |||
| 200 |
3/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 7 taken 16 times.
✓ Branch 8 taken 6 times.
|
22 | for (const auto& chunk : BuildChunks(keys)) { |
| 201 | 16 | void* recv = clients_[chunk.client_index]->GetReceiveBuffer( | |
| 202 |
1/2✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
|
16 | chunk.keys.size() * static_cast<std::size_t>(FLAGS_value_size) + |
| 203 | sizeof(std::int32_t)); | ||
| 204 |
1/2✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
|
32 | int rpc_id = clients_[chunk.client_index]->GetParameter( |
| 205 | 16 | base::ConstArray<uint64_t>(chunk.keys), | |
| 206 | static_cast<float*>(recv), | ||
| 207 | isAsync, | ||
| 208 | async_req_id); | ||
| 209 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | batch.shard_rpcs.push_back(PendingShardRpc{ |
| 210 | 16 | chunk.shard_id, | |
| 211 | 16 | chunk.client_index, | |
| 212 | rpc_id, | ||
| 213 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | chunk.positions, |
| 214 | recv, | ||
| 215 | 16 | chunk.keys.size(), | |
| 216 | }); | ||
| 217 | 6 | } | |
| 218 | |||
| 219 | 6 | std::uint64_t batch_id = 0; | |
| 220 | { | ||
| 221 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | std::lock_guard<std::mutex> guard(batches_mu_); |
| 222 | 6 | batch_id = batch_rpc_id_acc_++; | |
| 223 | 6 | if (batch_id > | |
| 224 |
2/2✓ Branch 1 taken 2 times.
✓ Branch 2 taken 4 times.
|
6 | static_cast<std::uint64_t>(std::numeric_limits<int>::max())) { |
| 225 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
|
2 | throw std::runtime_error("allshards batch rpc id overflow int range: " + |
| 226 |
1/2✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
4 | std::to_string(batch_id)); |
| 227 | } | ||
| 228 |
1/2✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | batches_[batch_id] = std::move(batch); |
| 229 | 6 | } | |
| 230 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | if (!isAsync) { |
| 231 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | WaitRPCFinish(static_cast<int>(batch_id)); |
| 232 | } | ||
| 233 | 4 | return static_cast<int>(batch_id); | |
| 234 | 6 | } | |
| 235 | |||
| 236 | 2 | void AllShardsParameterClientWrapper::InitThread() { | |
| 237 |
2/2✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
|
6 | for (auto* client : clients_) { |
| 238 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | client->InitThread(); |
| 239 | } | ||
| 240 | 2 | } | |
| 241 | |||
| 242 | ✗ | void AllShardsParameterClientWrapper::Barrier(const std::string& ss, int k) { | |
| 243 | ✗ | CHECK(!clients_.empty()); | |
| 244 | ✗ | clients_.front()->Barrier(ss, k); | |
| 245 | ✗ | } | |
| 246 | |||
| 247 | ✗ | void* AllShardsParameterClientWrapper::GetReceiveBuffer(size_t size) { | |
| 248 | ✗ | return new char[size]; | |
| 249 | } | ||
| 250 | |||
| 251 | ✗ | bool AllShardsParameterClientWrapper::QueryRPCFinished(int rpc_id) { | |
| 252 | ✗ | std::lock_guard<std::mutex> guard(batches_mu_); | |
| 253 | ✗ | auto it = batches_.find(rpc_id); | |
| 254 | ✗ | CHECK(it != batches_.end()); | |
| 255 | |||
| 256 | ✗ | for (const auto& pending : it->second.shard_rpcs) { | |
| 257 | ✗ | if (!clients_[pending.client_index]->QueryRPCFinished(pending.rpc_id)) { | |
| 258 | ✗ | return false; | |
| 259 | } | ||
| 260 | } | ||
| 261 | |||
| 262 | ✗ | return FinalizeBatchIfNeeded(&it->second); | |
| 263 | ✗ | } | |
| 264 | |||
| 265 | 8 | void AllShardsParameterClientWrapper::WaitRPCFinish(int rpc_id) { | |
| 266 | 8 | std::vector<PendingShardRpc> shard_rpcs; | |
| 267 | { | ||
| 268 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | std::lock_guard<std::mutex> guard(batches_mu_); |
| 269 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | auto it = batches_.find(rpc_id); |
| 270 |
2/12✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
|
8 | CHECK(it != batches_.end()); |
| 271 |
2/2✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
|
8 | if (it->second.assembled) { |
| 272 | 4 | return; | |
| 273 | } | ||
| 274 |
1/2✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | shard_rpcs = it->second.shard_rpcs; |
| 275 |
2/2✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
|
8 | } |
| 276 | |||
| 277 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | WaitShardRpcsCooperatively(shard_rpcs); |
| 278 | |||
| 279 | { | ||
| 280 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | std::lock_guard<std::mutex> guard(batches_mu_); |
| 281 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | auto it = batches_.find(rpc_id); |
| 282 |
2/12✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
|
4 | CHECK(it != batches_.end()); |
| 283 |
1/2✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | FinalizeBatchIfNeeded(&it->second); |
| 284 | 4 | } | |
| 285 |
2/2✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
|
8 | } |
| 286 | |||
| 287 | 4 | void AllShardsParameterClientWrapper::RevokeRPCResource(int rpc_id) { | |
| 288 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | std::lock_guard<std::mutex> guard(batches_mu_); |
| 289 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | auto it = batches_.find(rpc_id); |
| 290 |
2/12✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
|
4 | CHECK(it != batches_.end()); |
| 291 | |||
| 292 |
2/2✓ Branch 6 taken 12 times.
✓ Branch 7 taken 4 times.
|
16 | for (const auto& pending : it->second.shard_rpcs) { |
| 293 |
1/2✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
12 | clients_[pending.client_index]->RevokeRPCResource(pending.rpc_id); |
| 294 | } | ||
| 295 | |||
| 296 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | batches_.erase(it); |
| 297 | 4 | } | |
| 298 | |||
| 299 | 4 | int AllShardsParameterClientWrapper::PutParameter( | |
| 300 | const std::vector<uint64_t>& keys, | ||
| 301 | const std::vector<std::vector<float>>& values) { | ||
| 302 |
2/8✓ Branch 5 taken 4 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
|
4 | CHECK_EQ(keys.size(), values.size()); |
| 303 | |||
| 304 |
1/2✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | std::vector<std::vector<uint64_t>> shard_keys(num_shards_); |
| 305 |
1/2✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | std::vector<std::vector<std::vector<float>>> shard_values(num_shards_); |
| 306 | |||
| 307 |
2/2✓ Branch 1 taken 16 times.
✓ Branch 2 taken 4 times.
|
20 | for (std::size_t i = 0; i < keys.size(); ++i) { |
| 308 |
1/2✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
|
16 | const int shard = PartitionKey(keys[i]); |
| 309 |
1/2✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
|
16 | shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]); |
| 310 |
1/2✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
|
16 | shard_values[static_cast<std::size_t>(shard)].push_back(values[i]); |
| 311 | } | ||
| 312 | |||
| 313 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
|
12 | for (int shard = 0; shard < num_shards_; ++shard) { |
| 314 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const int client_index = shard_to_client_index_.at(shard); |
| 315 | 8 | for (std::size_t offset = 0; | |
| 316 |
2/2✓ Branch 2 taken 8 times.
✓ Branch 3 taken 8 times.
|
16 | offset < shard_keys[static_cast<std::size_t>(shard)].size(); |
| 317 | 8 | offset += static_cast<std::size_t>(FLAGS_max_kv_num_per_request)) { | |
| 318 | 8 | const std::size_t end = std::min( | |
| 319 | 16 | offset + static_cast<std::size_t>(FLAGS_max_kv_num_per_request), | |
| 320 | 8 | shard_keys[static_cast<std::size_t>(shard)].size()); | |
| 321 | std::vector<uint64_t> key_slice( | ||
| 322 | 8 | shard_keys[static_cast<std::size_t>(shard)].begin() + offset, | |
| 323 |
1/2✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
|
16 | shard_keys[static_cast<std::size_t>(shard)].begin() + end); |
| 324 | std::vector<std::vector<float>> value_slice( | ||
| 325 | 8 | shard_values[static_cast<std::size_t>(shard)].begin() + offset, | |
| 326 |
1/2✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
|
16 | shard_values[static_cast<std::size_t>(shard)].begin() + end); |
| 327 |
1/2✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | int rc = clients_[client_index]->PutParameter(key_slice, value_slice); |
| 328 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if (rc != 0) { |
| 329 | ✗ | return rc; | |
| 330 | } | ||
| 331 |
2/4✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
|
8 | } |
| 332 | } | ||
| 333 | |||
| 334 | 4 | return 0; | |
| 335 | 4 | } | |
| 336 | |||
| 337 | ✗ | int AllShardsParameterClientWrapper::InitEmbeddingTable( | |
| 338 | const std::string& table_name, | ||
| 339 | std::uint64_t num_embeddings, | ||
| 340 | std::uint64_t embedding_dim) { | ||
| 341 | ✗ | for (auto* client : clients_) { | |
| 342 | const int rc = | ||
| 343 | ✗ | client->InitEmbeddingTable(table_name, num_embeddings, embedding_dim); | |
| 344 | ✗ | if (rc != 0) { | |
| 345 | ✗ | return rc; | |
| 346 | } | ||
| 347 | } | ||
| 348 | ✗ | return 0; | |
| 349 | } | ||
| 350 | |||
| 351 | ✗ | int AllShardsParameterClientWrapper::UpdateParameter( | |
| 352 | const std::string& table_name, | ||
| 353 | base::ConstArray<uint64_t> keys, | ||
| 354 | const std::vector<std::vector<float>>* grads) { | ||
| 355 | ✗ | if (grads == nullptr) { | |
| 356 | ✗ | return -1; | |
| 357 | } | ||
| 358 | ✗ | if (keys.Size() != grads->size()) { | |
| 359 | ✗ | return -1; | |
| 360 | } | ||
| 361 | ✗ | if (keys.Size() == 0) { | |
| 362 | ✗ | return 0; | |
| 363 | } | ||
| 364 | |||
| 365 | ✗ | std::vector<std::vector<uint64_t>> shard_keys(num_shards_); | |
| 366 | ✗ | std::vector<std::vector<std::vector<float>>> shard_grads(num_shards_); | |
| 367 | |||
| 368 | ✗ | for (std::size_t i = 0; i < keys.Size(); ++i) { | |
| 369 | ✗ | const int shard = PartitionKey(keys[i]); | |
| 370 | ✗ | shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]); | |
| 371 | ✗ | shard_grads[static_cast<std::size_t>(shard)].push_back((*grads)[i]); | |
| 372 | } | ||
| 373 | |||
| 374 | ✗ | for (int shard = 0; shard < num_shards_; ++shard) { | |
| 375 | ✗ | if (shard_keys[static_cast<std::size_t>(shard)].empty()) { | |
| 376 | ✗ | continue; | |
| 377 | } | ||
| 378 | ✗ | const int client_index = shard_to_client_index_.at(shard); | |
| 379 | ✗ | const int rc = clients_[client_index]->UpdateParameter( | |
| 380 | table_name, | ||
| 381 | ✗ | base::ConstArray<uint64_t>(shard_keys[static_cast<std::size_t>(shard)]), | |
| 382 | ✗ | &shard_grads[static_cast<std::size_t>(shard)]); | |
| 383 | ✗ | if (rc != 0) { | |
| 384 | ✗ | return rc; | |
| 385 | } | ||
| 386 | } | ||
| 387 | |||
| 388 | ✗ | return 0; | |
| 389 | ✗ | } | |
| 390 |