GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 70.8% 366 / 0 / 517
Functions: 79.4% 27 / 0 / 34
Branches: 37.4% 314 / 0 / 840

ps/grpc/dist_grpc_ps_client.cpp
Line Branch Exec Source
#include "dist_grpc_ps_client.h"

#include <algorithm>
#include <cstring>
#include <fstream>
#include <future>
#include <iterator>
#include <memory>
#include <thread>

#include "base/factory.h"
#include "base/log.h"
#include "base/timer.h"
12
13 using grpc::Channel;
14 using grpc::ClientAsyncResponseReader;
15 using grpc::ClientContext;
16 using grpc::Status;
17 using recstoreps::CommandRequest;
18 using recstoreps::CommandResponse;
19 using recstoreps::GetParameterRequest;
20 using recstoreps::GetParameterResponse;
21 using recstoreps::InitEmbeddingTableRequest;
22 using recstoreps::InitEmbeddingTableResponse;
23 using recstoreps::PSCommand;
24 using recstoreps::PutParameterRequest;
25 using recstoreps::PutParameterResponse;
26 using recstoreps::UpdateParameterRequest;
27 using recstoreps::UpdateParameterResponse;
28
29 namespace recstore {
30
31 FACTORY_REGISTER(
32 BasePSClient, distributed_grpc, DistributedGRPCParameterClient, json);
33
34 18 DistributedGRPCParameterClient::DistributedGRPCParameterClient(json config)
35
1/2
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
18 : BasePSClient(config) {
36 18 json client_config;
37
38
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
18 if (config.contains("distributed_client")) {
39
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
36 LOG(INFO) << "Detected recstore config format, extracting "
40
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 "distributed_client section";
41
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
18 client_config = config["distributed_client"];
42 } else {
43 LOG(FATAL)
44 << "Invalid config format. Expected either recstore config with "
45 "'distributed_client' section "
46 << "or direct client config with 'servers' and 'num_shards' fields";
47 }
48
49 // 验证必要字段
50
3/6
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
36 if (!client_config.contains("servers") ||
51
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
18 !client_config["servers"].is_array()) {
52 LOG(FATAL)
53 << "Missing or invalid 'servers' field in distributed client config";
54 }
55
56
3/6
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
36 if (!client_config.contains("num_shards") ||
57
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
18 !client_config["num_shards"].is_number_integer()) {
58 LOG(FATAL)
59 << "Missing or invalid 'num_shards' field in distributed client config";
60 }
61
62
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
18 num_shards_ = client_config["num_shards"].get<int>();
63
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 max_keys_per_request_ = client_config.value("max_keys_per_request", 500);
64
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 hash_method_ = client_config.value("hash_method", "city_hash");
65
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
18 if (max_keys_per_request_ <= 0) {
66 LOG(FATAL) << "Invalid max_keys_per_request: " << max_keys_per_request_
67 << ", must be > 0";
68 }
69
70 // 解析服务器配置
71
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
18 auto servers = client_config["servers"];
72
1/2
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
18 server_configs_.reserve(servers.size());
73
74
2/2
✓ Branch 1 taken 36 times.
✓ Branch 2 taken 18 times.
54 for (size_t i = 0; i < servers.size(); ++i) {
75
1/2
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
36 const auto& server = servers[i];
76
5/10
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 36 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 36 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 36 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 36 times.
72 if (!server.contains("host") || !server.contains("port") ||
77
2/4
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 36 times.
36 !server.contains("shard")) {
78 LOG(FATAL) << "Server config " << i
79 << " missing required fields (host, port, shard)";
80 }
81
82 36 ServerConfig cfg;
83
2/4
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
36 cfg.host = server["host"].get<std::string>();
84
2/4
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
36 cfg.port = server["port"].get<int>();
85
2/4
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
36 cfg.shard = server["shard"].get<int>();
86
87
1/2
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
36 server_configs_.push_back(cfg);
88
1/2
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
36 shard_to_client_index_[cfg.shard] = i;
89 36 }
90
91
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
18 if (server_configs_.size() != static_cast<size_t>(num_shards_)) {
92 LOG(WARNING) << "Number of servers (" << server_configs_.size()
93 << ") doesn't match num_shards (" << num_shards_ << ")";
94 }
95
96
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 partitioned_key_buffer_.resize(num_shards_);
97
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 key_index_mapping_.resize(num_shards_);
98
99
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 InitializeClients();
100
101
3/6
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
36 LOG(INFO) << "Initialized DistributedGRPCParameterClient with " << num_shards_
102
3/6
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
18 << " shards, hash method: " << hash_method_;
103 18 }
104
105 26 DistributedGRPCParameterClient::~DistributedGRPCParameterClient() {}
106
107 18 void DistributedGRPCParameterClient::InitializeClients() {
108 18 clients_.clear();
109 18 clients_.reserve(server_configs_.size());
110
111
2/2
✓ Branch 5 taken 36 times.
✓ Branch 6 taken 18 times.
54 for (const auto& server_config : server_configs_) {
112 // 为每个服务器创建独立的配置
113 36 json client_config = {{"host", server_config.host},
114 36 {"port", server_config.port},
115
8/16
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 36 times.
✗ Branch 11 not taken.
✓ Branch 14 taken 36 times.
✗ Branch 15 not taken.
✓ Branch 17 taken 36 times.
✗ Branch 18 not taken.
✓ Branch 21 taken 36 times.
✗ Branch 22 not taken.
✓ Branch 24 taken 36 times.
✗ Branch 25 not taken.
468 {"shard", server_config.shard}};
116
117
3/6
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
36 auto* raw_client = new GRPCParameterClient(client_config);
118 36 auto client = std::unique_ptr<GRPCParameterClient>(raw_client);
119
1/2
✓ Branch 2 taken 36 times.
✗ Branch 3 not taken.
36 clients_.push_back(std::move(client));
120
121
3/6
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
72 LOG(INFO) << "Created gRPC client for shard " << server_config.shard
122
5/10
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 36 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 36 times.
✗ Branch 14 not taken.
36 << " at " << server_config.host << ":" << server_config.port;
123 36 }
124 18 }
125
126 2040 int DistributedGRPCParameterClient::GetShardId(uint64_t key) const {
127
1/2
✓ Branch 1 taken 2040 times.
✗ Branch 2 not taken.
2040 if (hash_method_ == "city_hash") {
128 2040 return GetHash(key) % num_shards_;
129 } else if (hash_method_ == "simple_mod") {
130 return key % num_shards_;
131 } else {
132 LOG(ERROR) << "Unknown hash method: " << hash_method_
133 << ", using city_hash";
134 return GetHash(key) % num_shards_;
135 }
136 }
137
138 106 void DistributedGRPCParameterClient::PartitionKeys(
139 const base::ConstArray<uint64_t>& keys,
140 std::vector<std::vector<uint64_t>>& partitioned_keys) const {
141
2/2
✓ Branch 5 taken 212 times.
✓ Branch 6 taken 106 times.
318 for (auto& partition : partitioned_key_buffer_) {
142 212 partition.clear();
143 }
144
2/2
✓ Branch 5 taken 212 times.
✓ Branch 6 taken 106 times.
318 for (auto& mapping : key_index_mapping_) {
145 212 mapping.clear();
146 }
147
148
2/2
✓ Branch 1 taken 1616 times.
✓ Branch 2 taken 106 times.
1722 for (size_t i = 0; i < keys.Size(); ++i) {
149 1616 uint64_t key = keys[i];
150
1/2
✓ Branch 1 taken 1616 times.
✗ Branch 2 not taken.
1616 int shard_id = GetShardId(key);
151
152
1/2
✓ Branch 2 taken 1616 times.
✗ Branch 3 not taken.
1616 partitioned_key_buffer_[shard_id].push_back(key);
153
1/2
✓ Branch 2 taken 1616 times.
✗ Branch 3 not taken.
1616 key_index_mapping_[shard_id].push_back(i);
154 }
155
156 106 partitioned_keys = partitioned_key_buffer_;
157 106 }
158
159 46 bool DistributedGRPCParameterClient::GetParameter(
160 const base::ConstArray<uint64_t>& keys,
161 std::vector<std::vector<float>>* values) {
162
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 46 times.
46 if (keys.Size() == 0) {
163 values->clear();
164 return true;
165 }
166
167
2/4
✓ Branch 2 taken 46 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 46 times.
✗ Branch 6 not taken.
92 xmh::Timer timer("DistributedGRPCParameterClient::GetParameter");
168
169 46 std::vector<std::vector<uint64_t>> partitioned_keys;
170
1/2
✓ Branch 1 taken 46 times.
✗ Branch 2 not taken.
46 PartitionKeys(keys, partitioned_keys);
171
172 46 std::vector<std::future<int>> futures;
173
1/2
✓ Branch 2 taken 46 times.
✗ Branch 3 not taken.
46 std::vector<std::vector<std::vector<float>>> partitioned_results(num_shards_);
174
175
2/2
✓ Branch 0 taken 92 times.
✓ Branch 1 taken 46 times.
138 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
176
2/2
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 84 times.
92 if (partitioned_keys[shard_id].empty()) {
177 8 continue;
178 }
179
180
1/2
✓ Branch 1 taken 84 times.
✗ Branch 2 not taken.
84 auto it = shard_to_client_index_.find(shard_id);
181
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 84 times.
84 if (it == shard_to_client_index_.end()) {
182 LOG(ERROR) << "No client found for shard " << shard_id;
183 return false;
184 }
185
186 84 int client_index = it->second;
187 84 auto* client = clients_[client_index].get();
188
189 // 异步请求
190
2/4
✓ Branch 1 taken 84 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
84 futures.push_back(std::async(
191 168 std::launch::async, [=, &partitioned_keys, &partitioned_results]() {
192 84 const auto& shard_keys_vec = partitioned_keys[shard_id];
193 84 auto& shard_result_vec = partitioned_results[shard_id];
194 84 shard_result_vec.clear();
195 84 shard_result_vec.reserve(shard_keys_vec.size());
196
197
2/2
✓ Branch 1 taken 88 times.
✓ Branch 2 taken 84 times.
172 for (size_t start = 0; start < shard_keys_vec.size();
198 88 start += static_cast<size_t>(max_keys_per_request_)) {
199 size_t end =
200 176 std::min(start + static_cast<size_t>(max_keys_per_request_),
201 88 shard_keys_vec.size());
202 base::ConstArray<uint64_t> shard_chunk(
203 88 shard_keys_vec.data() + start, static_cast<int>(end - start));
204 88 std::vector<std::vector<float>> chunk_result;
205
2/4
✓ Branch 1 taken 88 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 88 times.
88 if (!client->GetParameter(shard_chunk, &chunk_result)) {
206 return 0;
207 }
208
1/2
✓ Branch 5 taken 88 times.
✗ Branch 6 not taken.
88 shard_result_vec.insert(shard_result_vec.end(),
209 chunk_result.begin(),
210 chunk_result.end());
211
1/2
✓ Branch 1 taken 88 times.
✗ Branch 2 not taken.
88 }
212 84 return 1;
213 }));
214 }
215
216 // 3. 等待所有请求完成
217
2/2
✓ Branch 5 taken 84 times.
✓ Branch 6 taken 46 times.
130 for (auto& future : futures) {
218
2/4
✓ Branch 1 taken 84 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
84 if (!future.get()) {
219 LOG(ERROR) << "Failed to get parameters from one of the shards";
220 return false;
221 }
222 }
223
224 // 4. 合并结果
225
1/2
✓ Branch 1 taken 46 times.
✗ Branch 2 not taken.
46 MergeResults(keys, partitioned_results, values);
226
227 46 return true;
228 46 }
229
230 46 void DistributedGRPCParameterClient::MergeResults(
231 const base::ConstArray<uint64_t>& keys,
232 const std::vector<std::vector<std::vector<float>>>& partitioned_results,
233 std::vector<std::vector<float>>* values) const {
234 46 values->clear();
235
1/2
✓ Branch 2 taken 46 times.
✗ Branch 3 not taken.
46 values->resize(keys.Size());
236
237 // 重建key -> index映射
238 46 std::unordered_map<uint64_t, size_t> key_to_result_index;
239
2/2
✓ Branch 0 taken 92 times.
✓ Branch 1 taken 46 times.
138 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
240
2/2
✓ Branch 2 taken 368 times.
✓ Branch 3 taken 92 times.
460 for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) {
241 368 size_t original_index = key_index_mapping_[shard_id][i];
242
1/2
✓ Branch 2 taken 368 times.
✗ Branch 3 not taken.
368 if (i < partitioned_results[shard_id].size()) {
243
1/2
✓ Branch 4 taken 368 times.
✗ Branch 5 not taken.
368 (*values)[original_index] = partitioned_results[shard_id][i];
244 }
245 }
246 }
247 46 }
248
249 void DistributedGRPCParameterClient::MergeResultsToArray(
250 const base::ConstArray<uint64_t>& keys,
251 const std::vector<std::vector<std::vector<float>>>& partitioned_results,
252 float* values) const {
253 int emb_dim = 0;
254 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
255 if (!partitioned_results[shard_id].empty() &&
256 !partitioned_results[shard_id][0].empty()) {
257 emb_dim = partitioned_results[shard_id][0].size();
258 break;
259 }
260 }
261
262 if (emb_dim == 0) {
263 LOG(WARNING) << "No valid embeddings found";
264 return;
265 }
266
267 // 合并结果到连续内存
268 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
269 for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) {
270 size_t original_index = key_index_mapping_[shard_id][i];
271 if (i < partitioned_results[shard_id].size()) {
272 const auto& embedding = partitioned_results[shard_id][i];
273 std::copy(embedding.begin(),
274 embedding.end(),
275 values + original_index * emb_dim);
276 }
277 }
278 }
279 }
280
281 // 实现BasePSClient接口
282 32 int DistributedGRPCParameterClient::GetParameter(
283 const base::ConstArray<uint64_t>& keys, float* values) {
284 32 std::vector<std::vector<float>> result_vectors;
285
1/2
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
32 bool success = GetParameter(keys, &result_vectors);
286
287
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
32 if (!success) {
288 return -1;
289 }
290
291
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 32 times.
32 if (keys.Size() == 0) {
292 return 0;
293 }
294 32 int emb_dim = 0;
295
1/2
✓ Branch 5 taken 32 times.
✗ Branch 6 not taken.
32 for (const auto& row : result_vectors) {
296
1/2
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
32 if (!row.empty()) {
297 32 emb_dim = static_cast<int>(row.size());
298 32 break;
299 }
300 }
301
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
32 if (emb_dim == 0) {
302 LOG(WARNING) << "No valid embeddings found";
303 return 0;
304 }
305
306
2/2
✓ Branch 1 taken 92 times.
✓ Branch 2 taken 32 times.
124 for (size_t i = 0; i < result_vectors.size(); ++i) {
307 92 const auto& row = result_vectors[i];
308
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 92 times.
92 if (row.empty()) {
309 continue;
310 }
311
1/2
✓ Branch 3 taken 92 times.
✗ Branch 4 not taken.
92 std::copy(row.begin(), row.end(), values + i * emb_dim);
312 }
313 32 return 0;
314 32 }
315
316 int DistributedGRPCParameterClient::AsyncGetParameter(
317 const base::ConstArray<uint64_t>& keys, float* values) {
318 return GetParameter(keys, values);
319 }
320
321 58 int DistributedGRPCParameterClient::PutParameter(
322 const base::ConstArray<uint64_t>& keys,
323 const std::vector<std::vector<float>>& values) {
324
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 58 times.
58 if (keys.Size() != values.size()) {
325 LOG(ERROR) << "Keys and values size mismatch: " << keys.Size() << " vs "
326 << values.size();
327 return -1;
328 }
329
330 58 std::vector<std::vector<uint64_t>> partitioned_keys;
331
1/2
✓ Branch 1 taken 58 times.
✗ Branch 2 not taken.
58 PartitionKeys(keys, partitioned_keys);
332
333
1/2
✓ Branch 2 taken 58 times.
✗ Branch 3 not taken.
58 std::vector<std::vector<std::vector<float>>> partitioned_values(num_shards_);
334
2/2
✓ Branch 0 taken 116 times.
✓ Branch 1 taken 58 times.
174 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
335
2/2
✓ Branch 2 taken 1244 times.
✓ Branch 3 taken 116 times.
1360 for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) {
336 1244 size_t original_index = key_index_mapping_[shard_id][i];
337
1/2
✓ Branch 3 taken 1244 times.
✗ Branch 4 not taken.
1244 partitioned_values[shard_id].push_back(values[original_index]);
338 }
339 }
340
341 // 并发put到各个分片
342 58 std::vector<std::future<int>> futures;
343
344
2/2
✓ Branch 0 taken 116 times.
✓ Branch 1 taken 58 times.
174 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
345
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 116 times.
116 if (partitioned_keys[shard_id].empty()) {
346 continue;
347 }
348
349
1/2
✓ Branch 1 taken 116 times.
✗ Branch 2 not taken.
116 auto it = shard_to_client_index_.find(shard_id);
350
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 116 times.
116 if (it == shard_to_client_index_.end()) {
351 LOG(ERROR) << "No client found for shard " << shard_id;
352 return -1;
353 }
354
355 116 int client_index = it->second;
356 116 auto* client = clients_[client_index].get();
357
358
2/4
✓ Branch 1 taken 116 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 116 times.
✗ Branch 5 not taken.
116 futures.push_back(std::async(
359 232 std::launch::async, [=, &partitioned_keys, &partitioned_values]() {
360 116 const auto& shard_keys_vec = partitioned_keys[shard_id];
361 116 const auto& shard_vals_vec = partitioned_values[shard_id];
362
2/2
✓ Branch 1 taken 148 times.
✓ Branch 2 taken 116 times.
264 for (size_t start = 0; start < shard_keys_vec.size();
363 148 start += static_cast<size_t>(max_keys_per_request_)) {
364 size_t end =
365 296 std::min(start + static_cast<size_t>(max_keys_per_request_),
366 148 shard_keys_vec.size());
367 base::ConstArray<uint64_t> shard_chunk(
368 148 shard_keys_vec.data() + start, static_cast<int>(end - start));
369 std::vector<std::vector<float>> value_chunk(
370
1/2
✓ Branch 6 taken 148 times.
✗ Branch 7 not taken.
148 shard_vals_vec.begin() + start, shard_vals_vec.begin() + end);
371
2/4
✓ Branch 1 taken 148 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 148 times.
148 if (client->PutParameter(shard_chunk, value_chunk) != 1) {
372 return 0;
373 }
374
1/2
✓ Branch 1 taken 148 times.
✗ Branch 2 not taken.
148 }
375 116 return 1;
376 }));
377 }
378
379 // 等待所有请求完成
380
2/2
✓ Branch 5 taken 116 times.
✓ Branch 6 taken 58 times.
174 for (auto& future : futures) {
381
2/4
✓ Branch 1 taken 116 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 116 times.
116 if (future.get() != 1) {
382 LOG(ERROR) << "Failed to put parameters to one of the shards";
383 return -1;
384 }
385 }
386
387 58 return 0;
388 58 }
389
390 void DistributedGRPCParameterClient::Command(PSCommand command) {
391 std::vector<std::future<void>> futures;
392
393 for (auto& client : clients_) {
394 futures.push_back(std::async(std::launch::async, [&client, command]() {
395 client->Command(command);
396 }));
397 }
398
399 for (auto& future : futures) {
400 future.wait();
401 }
402 }
403
404 14 bool DistributedGRPCParameterClient::ClearPS() {
405 14 std::vector<std::future<bool>> futures;
406
407
2/2
✓ Branch 4 taken 28 times.
✓ Branch 5 taken 14 times.
42 for (auto& client : clients_) {
408
2/4
✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
28 futures.push_back(std::async(std::launch::async, [&client]() {
409 28 return client->ClearPS();
410 }));
411 }
412
413 14 bool all_success = true;
414
2/2
✓ Branch 5 taken 28 times.
✓ Branch 6 taken 14 times.
42 for (auto& future : futures) {
415
2/4
✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 28 times.
28 if (!future.get()) {
416 all_success = false;
417 }
418 }
419
420 14 return all_success;
421 14 }
422
423 2 bool DistributedGRPCParameterClient::LoadFakeData(int64_t n) {
424 2 std::vector<std::future<bool>> futures;
425
2/2
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
6 for (auto& client : clients_) {
426 4 GRPCParameterClient* raw = client.get();
427
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
4 futures.push_back(std::async(std::launch::async, [raw, n]() {
428 4 return raw->LoadFakeData(n);
429 }));
430 }
431 2 bool all_success = true;
432
2/2
✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
6 for (auto& future : futures) {
433
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
4 if (!future.get()) {
434 all_success = false;
435 }
436 }
437 2 return all_success;
438 2 }
439
440 2 bool DistributedGRPCParameterClient::DumpFakeData(int64_t n) {
441 2 std::vector<std::future<bool>> futures;
442
2/2
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
6 for (auto& client : clients_) {
443 4 GRPCParameterClient* raw = client.get();
444
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
4 futures.push_back(std::async(std::launch::async, [raw, n]() {
445 4 return raw->DumpFakeData(n);
446 }));
447 }
448 2 bool all_success = true;
449
2/2
✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
6 for (auto& future : futures) {
450
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
4 if (!future.get()) {
451 all_success = false;
452 }
453 }
454 2 return all_success;
455 2 }
456
457 bool DistributedGRPCParameterClient::LoadCkpt(
458 const std::vector<std::string>& model_config_path,
459 const std::vector<std::string>& emb_file_path) {
460 std::vector<std::future<bool>> futures;
461
462 for (auto& client : clients_) {
463 futures.push_back(std::async(
464 std::launch::async, [&client, &model_config_path, &emb_file_path]() {
465 return client->LoadCkpt(model_config_path, emb_file_path);
466 }));
467 }
468
469 bool all_success = true;
470 for (auto& future : futures) {
471 if (!future.get()) {
472 all_success = false;
473 }
474 }
475
476 return all_success;
477 }
478
479 2 int DistributedGRPCParameterClient::UpdateParameter(
480 const std::string& table_name,
481 const base::ConstArray<uint64_t>& keys,
482 const std::vector<std::vector<float>>* grads) {
483
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
2 if (grads == nullptr) {
484 LOG(ERROR) << "UpdateParameter grads pointer is null";
485 return -1;
486 }
487
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
2 if (keys.Size() != grads->size()) {
488 LOG(ERROR) << "UpdateParameter keys/grads size mismatch: " << keys.Size()
489 << " vs " << grads->size();
490 return -1;
491 }
492
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
2 if (keys.Size() == 0) {
493 return 0;
494 }
495
496 2 std::vector<std::vector<uint64_t>> partitioned_keys;
497
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 PartitionKeys(keys, partitioned_keys);
498
499
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 std::vector<std::vector<std::vector<float>>> partitioned_grads(num_shards_);
500
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
6 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
501
2/2
✓ Branch 2 taken 4 times.
✓ Branch 3 taken 4 times.
8 for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) {
502 4 size_t original_index = key_index_mapping_[shard_id][i];
503
1/2
✓ Branch 3 taken 4 times.
✗ Branch 4 not taken.
4 partitioned_grads[shard_id].push_back((*grads)[original_index]);
504 }
505 }
506
507 2 std::vector<std::future<int>> futures;
508
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
6 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
509
2/2
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 2 times.
4 if (partitioned_keys[shard_id].empty()) {
510 2 continue;
511 }
512
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 auto it = shard_to_client_index_.find(shard_id);
513
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
2 if (it == shard_to_client_index_.end()) {
514 LOG(ERROR) << "No client found for shard " << shard_id;
515 return -1;
516 }
517 2 int client_index = it->second;
518 2 auto* client = clients_[client_index].get();
519
520
3/6
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
2 futures.push_back(std::async(
521 4 std::launch::async, [=, &partitioned_keys, &partitioned_grads]() {
522 2 const auto& shard_keys_vec = partitioned_keys[shard_id];
523 2 const auto& shard_grads_vec = partitioned_grads[shard_id];
524
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
4 for (size_t start = 0; start < shard_keys_vec.size();
525 2 start += static_cast<size_t>(max_keys_per_request_)) {
526 size_t end =
527 4 std::min(start + static_cast<size_t>(max_keys_per_request_),
528 2 shard_keys_vec.size());
529 base::ConstArray<uint64_t> shard_chunk(
530 2 shard_keys_vec.data() + start, static_cast<int>(end - start));
531 std::vector<std::vector<float>> grad_chunk(
532
1/2
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
2 shard_grads_vec.begin() + start, shard_grads_vec.begin() + end);
533
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 if (client->UpdateParameter(table_name, shard_chunk, &grad_chunk) !=
534 0) {
535 return -1;
536 }
537
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 }
538 2 return 0;
539 }));
540 }
541
542
2/2
✓ Branch 5 taken 2 times.
✓ Branch 6 taken 2 times.
4 for (auto& future : futures) {
543
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 if (future.get() != 0) {
544 LOG(ERROR) << "Failed to update parameters on one of the shards";
545 return -1;
546 }
547 }
548
549 2 return 0;
550 2 }
551
552 2 int DistributedGRPCParameterClient::UpdateParameterFlat(
553 const std::string& table_name,
554 const base::ConstArray<uint64_t>& keys,
555 const float* grads,
556 int64_t num_rows,
557 int64_t embedding_dim) {
558
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
2 if (grads == nullptr) {
559 LOG(ERROR) << "UpdateParameterFlat grads pointer is null";
560 return -1;
561 }
562
2/4
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
2 if (num_rows < 0 || embedding_dim <= 0) {
563 LOG(ERROR) << "UpdateParameterFlat invalid shape: rows=" << num_rows
564 << " dim=" << embedding_dim;
565 return -1;
566 }
567
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
2 if (keys.Size() != static_cast<size_t>(num_rows)) {
568 LOG(ERROR) << "UpdateParameterFlat keys/grads size mismatch: "
569 << keys.Size() << " vs " << num_rows;
570 return -1;
571 }
572
573 2 std::vector<std::vector<float>> row_grads;
574
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 row_grads.reserve(static_cast<size_t>(num_rows));
575
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
6 for (int64_t i = 0; i < num_rows; ++i) {
576 4 const float* row = grads + i * embedding_dim;
577
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 row_grads.emplace_back(row, row + embedding_dim);
578 }
579
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 return UpdateParameter(table_name, keys, &row_grads);
580 2 }
581
582 10 int DistributedGRPCParameterClient::InitEmbeddingTable(
583 const std::string& table_name,
584 const recstore::EmbeddingTableConfig& config) {
585 10 std::vector<std::future<int>> futures;
586
2/2
✓ Branch 4 taken 20 times.
✓ Branch 5 taken 10 times.
30 for (auto& client : clients_) {
587
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 futures.push_back(
588
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
40 std::async(std::launch::async, [&client, &table_name, &config]() {
589 20 return client->InitEmbeddingTable(table_name, config);
590 }));
591 }
592
593
2/2
✓ Branch 5 taken 20 times.
✓ Branch 6 taken 10 times.
30 for (auto& future : futures) {
594
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
20 if (future.get() != 0) {
595 LOG(ERROR) << "InitEmbeddingTable failed on one of the shards";
596 return -1;
597 }
598 }
599
600 10 return 0;
601 10 }
602
603 // Prefetch 接口实现
604 38 uint64_t DistributedGRPCParameterClient::PrefetchParameter(
605 const base::ConstArray<uint64_t>& keys) {
606 auto cleanup_state = [this](const DistPrefetchState& state) {
607 for (const auto& shard_state : state.shard_states) {
608 if (shard_state.client_index < 0 ||
609 shard_state.client_index >= static_cast<int>(clients_.size())) {
610 continue;
611 }
612 auto* client = clients_[shard_state.client_index].get();
613 for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) {
614 client->WaitForPrefetch(child_prefetch_id);
615 std::vector<std::vector<float>> tmp;
616 client->GetPrefetchResult(child_prefetch_id, &tmp);
617 }
618 }
619 };
620
621
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 38 times.
38 if (keys.Size() == 0) {
622 std::lock_guard<std::mutex> lk(prefetch_mu_);
623 uint64_t prefetch_id = next_prefetch_id_++;
624 auto state = std::make_shared<DistPrefetchState>();
625 state->total_keys = 0;
626 prefetch_states_[prefetch_id] = state;
627 return prefetch_id;
628 }
629
630
1/2
✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
38 std::vector<std::vector<uint64_t>> shard_keys(num_shards_);
631
1/2
✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
38 std::vector<std::vector<size_t>> shard_indices(num_shards_);
632
2/2
✓ Branch 1 taken 424 times.
✓ Branch 2 taken 38 times.
462 for (size_t i = 0; i < keys.Size(); ++i) {
633
1/2
✓ Branch 2 taken 424 times.
✗ Branch 3 not taken.
424 const int shard_id = GetShardId(keys[i]);
634
1/2
✓ Branch 3 taken 424 times.
✗ Branch 4 not taken.
424 shard_keys[shard_id].push_back(keys[i]);
635
1/2
✓ Branch 2 taken 424 times.
✗ Branch 3 not taken.
424 shard_indices[shard_id].push_back(i);
636 }
637
638
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 auto state = std::make_shared<DistPrefetchState>();
639 38 state->total_keys = keys.Size();
640
641
2/2
✓ Branch 0 taken 76 times.
✓ Branch 1 taken 38 times.
114 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
642
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 76 times.
76 if (shard_keys[shard_id].empty()) {
643 continue;
644 }
645
646
1/2
✓ Branch 1 taken 76 times.
✗ Branch 2 not taken.
76 auto it = shard_to_client_index_.find(shard_id);
647
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 76 times.
76 if (it == shard_to_client_index_.end()) {
648 LOG(ERROR) << "No client found for shard " << shard_id;
649 cleanup_state(*state);
650 return 0;
651 }
652
653 76 DistPrefetchShardState shard_state;
654 76 shard_state.shard_id = shard_id;
655 76 shard_state.client_index = it->second;
656 76 shard_state.original_indices = std::move(shard_indices[shard_id]);
657
658 76 const auto& skeys = shard_keys[shard_id];
659
2/2
✓ Branch 1 taken 104 times.
✓ Branch 2 taken 76 times.
180 for (size_t start = 0; start < skeys.size();
660 104 start += static_cast<size_t>(max_keys_per_request_)) {
661 104 size_t end = std::min(
662 104 start + static_cast<size_t>(max_keys_per_request_), skeys.size());
663 base::ConstArray<uint64_t> chunk(
664 104 skeys.data() + start, static_cast<int>(end - start));
665 uint64_t child_prefetch_id =
666
1/2
✓ Branch 3 taken 104 times.
✗ Branch 4 not taken.
104 clients_[shard_state.client_index]->PrefetchParameter(chunk);
667
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 104 times.
104 if (child_prefetch_id == 0) {
668 LOG(ERROR) << "PrefetchParameter failed for shard " << shard_id;
669 cleanup_state(*state);
670 return 0;
671 }
672
1/2
✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
104 shard_state.child_prefetch_ids.push_back(child_prefetch_id);
673
1/2
✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
104 shard_state.chunk_sizes.push_back(static_cast<int>(end - start));
674 }
675
1/2
✓ Branch 3 taken 76 times.
✗ Branch 4 not taken.
76 state->shard_states.push_back(std::move(shard_state));
676
1/2
✓ Branch 1 taken 76 times.
✗ Branch 2 not taken.
76 }
677
678
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 std::lock_guard<std::mutex> lk(prefetch_mu_);
679 38 uint64_t prefetch_id = next_prefetch_id_++;
680
1/2
✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
38 prefetch_states_[prefetch_id] = std::move(state);
681 38 return prefetch_id;
682 38 }
683
684 4 bool DistributedGRPCParameterClient::IsPrefetchDone(uint64_t prefetch_id) {
685 4 std::shared_ptr<DistPrefetchState> state;
686 {
687
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 std::lock_guard<std::mutex> lk(prefetch_mu_);
688
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 auto it = prefetch_states_.find(prefetch_id);
689
2/2
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 2 times.
4 if (it == prefetch_states_.end()) {
690
4/8
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
2 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
691 2 return false;
692 }
693 2 state = it->second;
694
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
4 }
695
696
2/2
✓ Branch 6 taken 4 times.
✓ Branch 7 taken 2 times.
6 for (const auto& shard_state : state->shard_states) {
697 4 auto* client = clients_[shard_state.client_index].get();
698
2/2
✓ Branch 5 taken 4 times.
✓ Branch 6 taken 4 times.
8 for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) {
699
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
4 if (!client->IsPrefetchDone(child_prefetch_id)) {
700 return false;
701 }
702 }
703 }
704 2 return true;
705 4 }
706
707 74 void DistributedGRPCParameterClient::WaitForPrefetch(uint64_t prefetch_id) {
708 74 std::shared_ptr<DistPrefetchState> state;
709 {
710
1/2
✓ Branch 1 taken 74 times.
✗ Branch 2 not taken.
74 std::lock_guard<std::mutex> lk(prefetch_mu_);
711
1/2
✓ Branch 1 taken 74 times.
✗ Branch 2 not taken.
74 auto it = prefetch_states_.find(prefetch_id);
712
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 74 times.
74 if (it == prefetch_states_.end()) {
713 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
714 return;
715 }
716 74 state = it->second;
717
1/2
✓ Branch 1 taken 74 times.
✗ Branch 2 not taken.
74 }
718
719
2/2
✓ Branch 6 taken 148 times.
✓ Branch 7 taken 74 times.
222 for (const auto& shard_state : state->shard_states) {
720 148 auto* client = clients_[shard_state.client_index].get();
721
2/2
✓ Branch 5 taken 204 times.
✓ Branch 6 taken 148 times.
352 for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) {
722
1/2
✓ Branch 1 taken 204 times.
✗ Branch 2 not taken.
204 client->WaitForPrefetch(child_prefetch_id);
723 }
724 }
725
1/2
✓ Branch 1 taken 74 times.
✗ Branch 2 not taken.
74 }
726
727 42 bool DistributedGRPCParameterClient::GetPrefetchResult(
728 uint64_t prefetch_id, std::vector<std::vector<float>>* values) {
729
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 42 times.
42 if (values == nullptr) {
730 LOG(ERROR) << "GetPrefetchResult output pointer is null";
731 return false;
732 }
733
734 42 std::shared_ptr<DistPrefetchState> state;
735 {
736
1/2
✓ Branch 1 taken 42 times.
✗ Branch 2 not taken.
42 std::lock_guard<std::mutex> lk(prefetch_mu_);
737
1/2
✓ Branch 1 taken 42 times.
✗ Branch 2 not taken.
42 auto it = prefetch_states_.find(prefetch_id);
738
2/2
✓ Branch 2 taken 4 times.
✓ Branch 3 taken 38 times.
42 if (it == prefetch_states_.end()) {
739
4/8
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
4 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
740 4 return false;
741 }
742 38 state = it->second;
743
2/2
✓ Branch 1 taken 38 times.
✓ Branch 2 taken 4 times.
42 }
744
745 // Ensure all child RPCs are completed before consuming payloads.
746
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 WaitForPrefetch(prefetch_id);
747
748 38 values->clear();
749
1/2
✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
38 values->resize(state->total_keys);
750
751 38 bool ok_all = true;
752
2/2
✓ Branch 6 taken 76 times.
✓ Branch 7 taken 38 times.
114 for (const auto& shard_state : state->shard_states) {
753 76 auto* client = clients_[shard_state.client_index].get();
754 76 size_t shard_offset = 0;
755
2/2
✓ Branch 1 taken 104 times.
✓ Branch 2 taken 76 times.
180 for (size_t i = 0; i < shard_state.child_prefetch_ids.size(); ++i) {
756 104 std::vector<std::vector<float>> chunk_values;
757
2/4
✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 104 times.
104 if (!client->GetPrefetchResult(
758 104 shard_state.child_prefetch_ids[i], &chunk_values)) {
759 ok_all = false;
760 break;
761 }
762 const int expected =
763 104 (i < shard_state.chunk_sizes.size()
764
1/2
✓ Branch 0 taken 104 times.
✗ Branch 1 not taken.
104 ? shard_state.chunk_sizes[i]
765 104 : -1);
766
3/6
✓ Branch 0 taken 104 times.
✗ Branch 1 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 104 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 104 times.
104 if (expected >= 0 && static_cast<int>(chunk_values.size()) != expected) {
767 LOG(ERROR) << "Prefetch chunk size mismatch: got "
768 << chunk_values.size() << ", expected " << expected;
769 ok_all = false;
770 break;
771 }
772
2/2
✓ Branch 5 taken 424 times.
✓ Branch 6 taken 104 times.
528 for (const auto& row : chunk_values) {
773
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 424 times.
424 if (shard_offset >= shard_state.original_indices.size()) {
774 LOG(ERROR) << "Prefetch result overflow in shard "
775 << shard_state.shard_id;
776 ok_all = false;
777 break;
778 }
779
1/2
✓ Branch 3 taken 424 times.
✗ Branch 4 not taken.
424 (*values)[shard_state.original_indices[shard_offset++]] = row;
780 }
781
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 104 times.
104 if (!ok_all) {
782 break;
783 }
784
1/2
✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
104 }
785
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 76 times.
76 if (!ok_all) {
786 break;
787 }
788 }
789
790 {
791
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 std::lock_guard<std::mutex> lk(prefetch_mu_);
792
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 prefetch_states_.erase(prefetch_id);
793 38 }
794 38 return ok_all;
795 42 }
796
797 6 bool DistributedGRPCParameterClient::GetPrefetchResultFlat(
798 uint64_t prefetch_id,
799 std::vector<float>* values,
800 int64_t* num_rows,
801 int64_t embedding_dim) {
802
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
6 if (values == nullptr || num_rows == nullptr) {
803 LOG(ERROR) << "GetPrefetchResultFlat output pointer is null";
804 return false;
805 }
806
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
6 if (embedding_dim <= 0) {
807 LOG(ERROR) << "GetPrefetchResultFlat invalid embedding_dim: "
808 << embedding_dim;
809 return false;
810 }
811
812 6 std::vector<std::vector<float>> merged_values;
813
3/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✓ Branch 4 taken 4 times.
6 if (!GetPrefetchResult(prefetch_id, &merged_values)) {
814 2 return false;
815 }
816
817 4 *num_rows = static_cast<int64_t>(merged_values.size());
818 4 values->assign(
819 4 static_cast<size_t>(*num_rows) * static_cast<size_t>(embedding_dim),
820
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 0.0f);
821
2/2
✓ Branch 1 taken 24 times.
✓ Branch 2 taken 4 times.
28 for (size_t i = 0; i < merged_values.size(); ++i) {
822 24 const auto& row = merged_values[i];
823
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
24 if (row.empty()) {
824 continue;
825 }
826 const int64_t copy_d =
827 24 std::min<int64_t>(embedding_dim, static_cast<int64_t>(row.size()));
828 24 std::memcpy(values->data() + i * static_cast<size_t>(embedding_dim),
829 24 row.data(),
830 24 static_cast<size_t>(copy_d) * sizeof(float));
831 }
832 4 return true;
833 6 }
834
835 } // namespace recstore
836