GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 74.9% 182 / 0 / 243
Functions: 70.6% 12 / 0 / 17
Branches: 38.9% 129 / 0 / 332

ps/rdma/allshards_ps_client.cc
Line Branch Exec Source
1 #include "allshards_ps_client.h"
2
3 #include <algorithm>
4 #include <boost/coroutine2/all.hpp>
5 #include <cstring>
6 #include <limits>
7 #include <memory>
8 #include <stdexcept>
9 #include <thread>
10 #include <vector>
11
12 #include "base/hash.h"
13 #include "ps/rdma/rdma_common.h"
14
15 DECLARE_int32(value_size);
16 DECLARE_int32(max_kv_num_per_request);
17
18 10 AllShardsParameterClientWrapper::AllShardsParameterClientWrapper(
19 const std::vector<BaseParameterClient*>& clients,
20 int num_shards,
21 const std::string& hash_method,
22 10 const std::vector<int>& shard_ids)
23 : BaseParameterClient("", 0, 0),
24 10 clients_(clients),
25 10 num_shards_(num_shards),
26
4/8
✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 14 not taken.
10 hash_method_(hash_method) {
27
2/8
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
10 CHECK_EQ(static_cast<int>(clients_.size()), num_shards_);
28
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 8 times.
10 if (!shard_ids.empty()) {
29
2/8
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
2 CHECK_EQ(static_cast<int>(shard_ids.size()), num_shards_);
30
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
6 for (int i = 0; i < num_shards_; ++i) {
31
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 shard_to_client_index_[shard_ids[static_cast<std::size_t>(i)]] = i;
32 }
33 } else {
34
2/2
✓ Branch 0 taken 16 times.
✓ Branch 1 taken 8 times.
24 for (int i = 0; i < num_shards_; ++i) {
35
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
16 shard_to_client_index_[i] = i;
36 }
37 }
38 10 }
39
40 38 int AllShardsParameterClientWrapper::PartitionKey(uint64_t key) const {
41
1/6
✗ Branch 3 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
38 CHECK_GT(num_shards_, 0);
42
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 if (hash_method_ == "city_hash") {
43 38 return static_cast<int>(GetHash(key) % static_cast<uint64_t>(num_shards_));
44 }
45 if (hash_method_ == "simple_mod") {
46 return static_cast<int>(key % static_cast<uint64_t>(num_shards_));
47 }
48 throw std::runtime_error("unsupported shard hash method: " + hash_method_);
49 }
50
51 std::vector<AllShardsParameterClientWrapper::ShardChunk>
52 6 AllShardsParameterClientWrapper::BuildChunks(
53 base::ConstArray<uint64_t> keys) const {
54
1/2
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 std::vector<std::vector<uint64_t>> shard_keys(num_shards_);
55
1/2
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 std::vector<std::vector<std::size_t>> shard_positions(num_shards_);
56
57
2/2
✓ Branch 1 taken 22 times.
✓ Branch 2 taken 6 times.
28 for (std::size_t i = 0; i < keys.Size(); ++i) {
58
1/2
✓ Branch 2 taken 22 times.
✗ Branch 3 not taken.
22 const int shard = PartitionKey(keys[i]);
59
1/2
✓ Branch 3 taken 22 times.
✗ Branch 4 not taken.
22 shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]);
60
1/2
✓ Branch 2 taken 22 times.
✗ Branch 3 not taken.
22 shard_positions[static_cast<std::size_t>(shard)].push_back(i);
61 }
62
63 6 std::vector<ShardChunk> chunks;
64
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
18 for (int shard = 0; shard < num_shards_; ++shard) {
65
1/2
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
12 const int client_index = shard_to_client_index_.at(shard);
66 12 for (std::size_t offset = 0;
67
2/2
✓ Branch 2 taken 16 times.
✓ Branch 3 taken 12 times.
28 offset < shard_keys[static_cast<std::size_t>(shard)].size();
68 16 offset += static_cast<std::size_t>(FLAGS_max_kv_num_per_request)) {
69 16 const std::size_t end = std::min(
70 32 offset + static_cast<std::size_t>(FLAGS_max_kv_num_per_request),
71 16 shard_keys[static_cast<std::size_t>(shard)].size());
72 16 ShardChunk chunk;
73 16 chunk.shard_id = shard;
74 16 chunk.client_index = client_index;
75
1/2
✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
32 chunk.keys.assign(
76 16 shard_keys[static_cast<std::size_t>(shard)].begin() + offset,
77 16 shard_keys[static_cast<std::size_t>(shard)].begin() + end);
78
1/2
✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
32 chunk.positions.assign(
79 16 shard_positions[static_cast<std::size_t>(shard)].begin() + offset,
80 16 shard_positions[static_cast<std::size_t>(shard)].begin() + end);
81
1/2
✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
16 chunks.push_back(std::move(chunk));
82 16 }
83 }
84 12 return chunks;
85 6 }
86
87 4 bool AllShardsParameterClientWrapper::FinalizeBatchIfNeeded(
88 BatchRequest* batch) {
89
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
4 if (batch->assembled) {
90 return batch->status_code ==
91 static_cast<std::int32_t>(petps::RpcStatus::kOk);
92 }
93
94 4 batch->status_code = static_cast<std::int32_t>(petps::RpcStatus::kOk);
95
2/2
✓ Branch 5 taken 12 times.
✓ Branch 6 taken 2 times.
14 for (const auto& pending : batch->shard_rpcs) {
96 12 const auto* status_word = petps::FixedSlotStatusWord(
97 12 pending.recv_buffer, pending.key_count, FLAGS_value_size);
98
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 10 times.
12 if (*status_word != static_cast<std::int32_t>(petps::RpcStatus::kOk)) {
99 2 batch->status_code = *status_word;
100 2 break;
101 }
102 }
103
104 4 const int embedding_dim = FLAGS_value_size / sizeof(float);
105
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 2 times.
4 if (batch->status_code == static_cast<std::int32_t>(petps::RpcStatus::kOk)) {
106
2/2
✓ Branch 5 taken 8 times.
✓ Branch 6 taken 2 times.
10 for (const auto& pending : batch->shard_rpcs) {
107 8 const float* shard_values =
108 static_cast<const float*>(pending.recv_buffer);
109
2/2
✓ Branch 1 taken 14 times.
✓ Branch 2 taken 8 times.
22 for (std::size_t i = 0; i < pending.original_positions.size(); ++i) {
110 28 std::memcpy(
111 14 batch->user_buffer + pending.original_positions[i] * embedding_dim,
112 14 shard_values + i * embedding_dim,
113 FLAGS_value_size);
114 }
115 }
116 }
117
118 4 auto* batch_status_word = reinterpret_cast<std::int32_t*>(
119 4 reinterpret_cast<char*>(batch->user_buffer) +
120 4 batch->total_key_count * static_cast<std::size_t>(FLAGS_value_size));
121 4 *batch_status_word = batch->status_code;
122 4 batch->assembled = true;
123 4 return batch->status_code == static_cast<std::int32_t>(petps::RpcStatus::kOk);
124 }
125
126 4 void AllShardsParameterClientWrapper::WaitShardRpcsCooperatively(
127 const std::vector<PendingShardRpc>& shard_rpcs) const {
128 using Coroutine = boost::coroutines2::coroutine<void>;
129 4 std::vector<std::unique_ptr<Coroutine::pull_type>> waiters;
130
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 waiters.reserve(shard_rpcs.size());
131
2/2
✓ Branch 4 taken 12 times.
✓ Branch 5 taken 4 times.
16 for (const auto& pending : shard_rpcs) {
132
3/6
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 8 not taken.
12 waiters.emplace_back(std::make_unique<Coroutine::pull_type>(
133 12 [this, pending](Coroutine::push_type& sink) {
134 auto* client =
135 12 clients_[static_cast<std::size_t>(pending.client_index)];
136
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
12 while (!client->QueryRPCFinished(pending.rpc_id)) {
137 sink();
138 }
139 12 client->WaitRPCFinish(pending.rpc_id);
140 12 }));
141 }
142
143
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
8 while (!waiters.empty()) {
144
2/2
✓ Branch 3 taken 12 times.
✓ Branch 4 taken 4 times.
16 for (auto it = waiters.begin(); it != waiters.end();) {
145 12 auto& waiter = *it;
146
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
12 if (*waiter) {
147 (*waiter)();
148 }
149
1/2
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
12 if (!*waiter) {
150
1/2
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
12 it = waiters.erase(it);
151 } else {
152 ++it;
153 }
154 }
155
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
4 if (!waiters.empty()) {
156 std::this_thread::yield();
157 }
158 }
159 4 }
160
161 2 int AllShardsParameterClientWrapper::GetParameter(
162 base::ConstArray<uint64_t> keys, std::vector<std::vector<float>>* values) {
163 2 values->clear();
164
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
2 if (keys.Size() == 0) {
165 return 0;
166 }
167
168 2 const int embedding_dim = FLAGS_value_size / sizeof(float);
169
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 std::vector<float> flat(keys.Size() * embedding_dim + 1, 0.0f);
170
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 int rpc_id = GetParameter(keys, flat.data(), false, 0);
171
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 WaitRPCFinish(rpc_id);
172 const auto* status_word =
173 2 petps::FixedSlotStatusWord(flat.data(), keys.Size(), FLAGS_value_size);
174
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 if (*status_word != static_cast<std::int32_t>(petps::RpcStatus::kOk)) {
175
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 RevokeRPCResource(rpc_id);
176 2 return -1;
177 }
178
179 petps::CopyFlatRowsToVectors(
180 flat.data(),
181 keys.Size(),
182 static_cast<std::size_t>(embedding_dim),
183 values);
184 RevokeRPCResource(rpc_id);
185 return 0;
186 2 }
187
188 6 int AllShardsParameterClientWrapper::GetParameter(
189 base::ConstArray<uint64_t> keys,
190 float* values,
191 bool isAsync,
192 int async_req_id) {
193 6 BatchRequest batch;
194 6 batch.user_buffer = values;
195 6 batch.total_key_count = keys.Size();
196 auto* batch_status_word =
197 6 petps::FixedSlotStatusWord(values, keys.Size(), FLAGS_value_size);
198 6 *batch_status_word = static_cast<std::int32_t>(petps::RpcStatus::kPending);
199
200
3/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 7 taken 16 times.
✓ Branch 8 taken 6 times.
22 for (const auto& chunk : BuildChunks(keys)) {
201 16 void* recv = clients_[chunk.client_index]->GetReceiveBuffer(
202
1/2
✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
16 chunk.keys.size() * static_cast<std::size_t>(FLAGS_value_size) +
203 sizeof(std::int32_t));
204
1/2
✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
32 int rpc_id = clients_[chunk.client_index]->GetParameter(
205 16 base::ConstArray<uint64_t>(chunk.keys),
206 static_cast<float*>(recv),
207 isAsync,
208 async_req_id);
209
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
16 batch.shard_rpcs.push_back(PendingShardRpc{
210 16 chunk.shard_id,
211 16 chunk.client_index,
212 rpc_id,
213
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
16 chunk.positions,
214 recv,
215 16 chunk.keys.size(),
216 });
217 6 }
218
219 6 std::uint64_t batch_id = 0;
220 {
221
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 std::lock_guard<std::mutex> guard(batches_mu_);
222 6 batch_id = batch_rpc_id_acc_++;
223 6 if (batch_id >
224
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 4 times.
6 static_cast<std::uint64_t>(std::numeric_limits<int>::max())) {
225
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
2 throw std::runtime_error("allshards batch rpc id overflow int range: " +
226
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
4 std::to_string(batch_id));
227 }
228
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 batches_[batch_id] = std::move(batch);
229 6 }
230
1/2
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
4 if (!isAsync) {
231
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 WaitRPCFinish(static_cast<int>(batch_id));
232 }
233 4 return static_cast<int>(batch_id);
234 6 }
235
236 2 void AllShardsParameterClientWrapper::InitThread() {
237
2/2
✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
6 for (auto* client : clients_) {
238
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 client->InitThread();
239 }
240 2 }
241
242 void AllShardsParameterClientWrapper::Barrier(const std::string& ss, int k) {
243 CHECK(!clients_.empty());
244 clients_.front()->Barrier(ss, k);
245 }
246
247 void* AllShardsParameterClientWrapper::GetReceiveBuffer(size_t size) {
248 return new char[size];
249 }
250
251 bool AllShardsParameterClientWrapper::QueryRPCFinished(int rpc_id) {
252 std::lock_guard<std::mutex> guard(batches_mu_);
253 auto it = batches_.find(rpc_id);
254 CHECK(it != batches_.end());
255
256 for (const auto& pending : it->second.shard_rpcs) {
257 if (!clients_[pending.client_index]->QueryRPCFinished(pending.rpc_id)) {
258 return false;
259 }
260 }
261
262 return FinalizeBatchIfNeeded(&it->second);
263 }
264
265 8 void AllShardsParameterClientWrapper::WaitRPCFinish(int rpc_id) {
266 8 std::vector<PendingShardRpc> shard_rpcs;
267 {
268
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 std::lock_guard<std::mutex> guard(batches_mu_);
269
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 auto it = batches_.find(rpc_id);
270
2/12
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
8 CHECK(it != batches_.end());
271
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
8 if (it->second.assembled) {
272 4 return;
273 }
274
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 shard_rpcs = it->second.shard_rpcs;
275
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
8 }
276
277
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 WaitShardRpcsCooperatively(shard_rpcs);
278
279 {
280
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 std::lock_guard<std::mutex> guard(batches_mu_);
281
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 auto it = batches_.find(rpc_id);
282
2/12
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
4 CHECK(it != batches_.end());
283
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 FinalizeBatchIfNeeded(&it->second);
284 4 }
285
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
8 }
286
287 4 void AllShardsParameterClientWrapper::RevokeRPCResource(int rpc_id) {
288
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 std::lock_guard<std::mutex> guard(batches_mu_);
289
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 auto it = batches_.find(rpc_id);
290
2/12
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
4 CHECK(it != batches_.end());
291
292
2/2
✓ Branch 6 taken 12 times.
✓ Branch 7 taken 4 times.
16 for (const auto& pending : it->second.shard_rpcs) {
293
1/2
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
12 clients_[pending.client_index]->RevokeRPCResource(pending.rpc_id);
294 }
295
296
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 batches_.erase(it);
297 4 }
298
299 4 int AllShardsParameterClientWrapper::PutParameter(
300 const std::vector<uint64_t>& keys,
301 const std::vector<std::vector<float>>& values) {
302
2/8
✓ Branch 5 taken 4 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
4 CHECK_EQ(keys.size(), values.size());
303
304
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 std::vector<std::vector<uint64_t>> shard_keys(num_shards_);
305
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 std::vector<std::vector<std::vector<float>>> shard_values(num_shards_);
306
307
2/2
✓ Branch 1 taken 16 times.
✓ Branch 2 taken 4 times.
20 for (std::size_t i = 0; i < keys.size(); ++i) {
308
1/2
✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
16 const int shard = PartitionKey(keys[i]);
309
1/2
✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
16 shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]);
310
1/2
✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
16 shard_values[static_cast<std::size_t>(shard)].push_back(values[i]);
311 }
312
313
2/2
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
12 for (int shard = 0; shard < num_shards_; ++shard) {
314
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 const int client_index = shard_to_client_index_.at(shard);
315 8 for (std::size_t offset = 0;
316
2/2
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 8 times.
16 offset < shard_keys[static_cast<std::size_t>(shard)].size();
317 8 offset += static_cast<std::size_t>(FLAGS_max_kv_num_per_request)) {
318 8 const std::size_t end = std::min(
319 16 offset + static_cast<std::size_t>(FLAGS_max_kv_num_per_request),
320 8 shard_keys[static_cast<std::size_t>(shard)].size());
321 std::vector<uint64_t> key_slice(
322 8 shard_keys[static_cast<std::size_t>(shard)].begin() + offset,
323
1/2
✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
16 shard_keys[static_cast<std::size_t>(shard)].begin() + end);
324 std::vector<std::vector<float>> value_slice(
325 8 shard_values[static_cast<std::size_t>(shard)].begin() + offset,
326
1/2
✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
16 shard_values[static_cast<std::size_t>(shard)].begin() + end);
327
1/2
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
8 int rc = clients_[client_index]->PutParameter(key_slice, value_slice);
328
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
8 if (rc != 0) {
329 return rc;
330 }
331
2/4
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
8 }
332 }
333
334 4 return 0;
335 4 }
336
337 int AllShardsParameterClientWrapper::InitEmbeddingTable(
338 const std::string& table_name,
339 std::uint64_t num_embeddings,
340 std::uint64_t embedding_dim) {
341 for (auto* client : clients_) {
342 const int rc =
343 client->InitEmbeddingTable(table_name, num_embeddings, embedding_dim);
344 if (rc != 0) {
345 return rc;
346 }
347 }
348 return 0;
349 }
350
351 int AllShardsParameterClientWrapper::UpdateParameter(
352 const std::string& table_name,
353 base::ConstArray<uint64_t> keys,
354 const std::vector<std::vector<float>>* grads) {
355 if (grads == nullptr) {
356 return -1;
357 }
358 if (keys.Size() != grads->size()) {
359 return -1;
360 }
361 if (keys.Size() == 0) {
362 return 0;
363 }
364
365 std::vector<std::vector<uint64_t>> shard_keys(num_shards_);
366 std::vector<std::vector<std::vector<float>>> shard_grads(num_shards_);
367
368 for (std::size_t i = 0; i < keys.Size(); ++i) {
369 const int shard = PartitionKey(keys[i]);
370 shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]);
371 shard_grads[static_cast<std::size_t>(shard)].push_back((*grads)[i]);
372 }
373
374 for (int shard = 0; shard < num_shards_; ++shard) {
375 if (shard_keys[static_cast<std::size_t>(shard)].empty()) {
376 continue;
377 }
378 const int client_index = shard_to_client_index_.at(shard);
379 const int rc = clients_[client_index]->UpdateParameter(
380 table_name,
381 base::ConstArray<uint64_t>(shard_keys[static_cast<std::size_t>(shard)]),
382 &shard_grads[static_cast<std::size_t>(shard)]);
383 if (rc != 0) {
384 return rc;
385 }
386 }
387
388 return 0;
389 }
390