GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 33.4% 172 / 0 / 515
Functions: 47.1% 16 / 0 / 34
Branches: 19.4% 161 / 0 / 832

ps/brpc/dist_brpc_ps_client.cpp
Line Branch Exec Source
1 #include "dist_brpc_ps_client.h"
2
3 #include <algorithm>
4 #include <cstring>
5 #include <fstream>
6 #include <future>
7 #include <thread>
8
9 #include "base/factory.h"
10 #include "base/flatc.h"
11 #include "base/log.h"
12 #include "base/timer.h"
13
14 #ifdef ENABLE_PERF_REPORT
15 # include <chrono>
16 # include "base/report/report_client.h"
17 #endif
18
19 using recstoreps_brpc::CommandRequest;
20 using recstoreps_brpc::CommandResponse;
21 using recstoreps_brpc::GetParameterRequest;
22 using recstoreps_brpc::GetParameterResponse;
23 using recstoreps_brpc::PSCommand;
24 using recstoreps_brpc::PutParameterRequest;
25 using recstoreps_brpc::PutParameterResponse;
26
27 namespace recstore {
28
29 FACTORY_REGISTER(
30 BasePSClient, distributed_brpc, DistributedBRPCParameterClient, json);
31
32 10 DistributedBRPCParameterClient::DistributedBRPCParameterClient(json config)
33
1/2
✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
10 : BasePSClient(config) {
34 10 json client_config;
35
36
2/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
10 if (config.contains("distributed_client")) {
37
2/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
20 LOG(INFO) << "Detected recstore config format, extracting "
38
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 "distributed_client section";
39
2/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
10 client_config = config["distributed_client"];
40 } else {
41 LOG(FATAL)
42 << "Invalid config format. Expected either recstore config with "
43 "'distributed_client' section "
44 << "or direct client config with 'servers' and 'num_shards' fields";
45 }
46
47 // 验证必要字段
48
3/6
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
20 if (!client_config.contains("servers") ||
49
2/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 10 times.
10 !client_config["servers"].is_array()) {
50 LOG(FATAL)
51 << "Missing or invalid 'servers' field in distributed client config";
52 }
53
54
3/6
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
20 if (!client_config.contains("num_shards") ||
55
2/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 10 times.
10 !client_config["num_shards"].is_number_integer()) {
56 LOG(FATAL)
57 << "Missing or invalid 'num_shards' field in distributed client config";
58 }
59
60
2/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
10 num_shards_ = client_config["num_shards"].get<int>();
61
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 max_keys_per_request_ = client_config.value("max_keys_per_request", 500);
62
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 hash_method_ = client_config.value("hash_method", "city_hash");
63
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
10 if (max_keys_per_request_ <= 0) {
64 LOG(FATAL) << "Invalid max_keys_per_request: " << max_keys_per_request_
65 << ", must be > 0";
66 }
67
68 // 解析服务器配置
69
2/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
10 auto servers = client_config["servers"];
70
1/2
✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
10 server_configs_.reserve(servers.size());
71
72
2/2
✓ Branch 1 taken 20 times.
✓ Branch 2 taken 10 times.
30 for (size_t i = 0; i < servers.size(); ++i) {
73
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 const auto& server = servers[i];
74
5/10
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 20 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 20 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 20 times.
40 if (!server.contains("host") || !server.contains("port") ||
75
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
20 !server.contains("shard")) {
76 LOG(FATAL) << "Server config " << i
77 << " missing required fields (host, port, shard)";
78 }
79
80 20 ServerConfig cfg;
81
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
20 cfg.host = server["host"].get<std::string>();
82
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
20 cfg.port = server["port"].get<int>();
83
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
20 cfg.shard = server["shard"].get<int>();
84
85
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 server_configs_.push_back(cfg);
86
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 shard_to_client_index_[cfg.shard] = i;
87 20 }
88
89
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 10 times.
10 if (server_configs_.size() != static_cast<size_t>(num_shards_)) {
90 LOG(WARNING) << "Number of servers (" << server_configs_.size()
91 << ") doesn't match num_shards (" << num_shards_ << ")";
92 }
93
94
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 partitioned_key_buffer_.resize(num_shards_);
95
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 key_index_mapping_.resize(num_shards_);
96
97
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 InitializeClients();
98
99
3/6
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 8 not taken.
20 LOG(INFO) << "Initialized DistributedBRPCParameterClient with " << num_shards_
100
3/6
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 8 not taken.
10 << " shards, hash method: " << hash_method_;
101 10 }
102
103 14 DistributedBRPCParameterClient::~DistributedBRPCParameterClient() {}
104
105 10 void DistributedBRPCParameterClient::InitializeClients() {
106 10 clients_.clear();
107 10 clients_.reserve(server_configs_.size());
108
109
2/2
✓ Branch 5 taken 20 times.
✓ Branch 6 taken 10 times.
30 for (const auto& server_config : server_configs_) {
110 // 为每个服务器创建独立的配置
111 20 json client_config = {{"host", server_config.host},
112 20 {"port", server_config.port},
113
8/16
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
✓ Branch 14 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 17 taken 20 times.
✗ Branch 18 not taken.
✓ Branch 21 taken 20 times.
✗ Branch 22 not taken.
✓ Branch 24 taken 20 times.
✗ Branch 25 not taken.
260 {"shard", server_config.shard}};
114
115
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 auto client = std::make_unique<BRPCParameterClient>(client_config);
116
1/2
✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
20 clients_.push_back(std::move(client));
117
118
3/6
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
40 LOG(INFO) << "Created bRPC client for shard " << server_config.shard
119
5/10
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 20 times.
✗ Branch 14 not taken.
20 << " at " << server_config.host << ":" << server_config.port;
120 20 }
121 10 }
122
123 448 int DistributedBRPCParameterClient::GetShardId(uint64_t key) const {
124
1/2
✓ Branch 1 taken 448 times.
✗ Branch 2 not taken.
448 if (hash_method_ == "city_hash") {
125 448 return GetHash(key) % num_shards_;
126 } else if (hash_method_ == "simple_mod") {
127 return key % num_shards_;
128 } else {
129 LOG(ERROR) << "Unknown hash method: " << hash_method_
130 << ", using city_hash";
131 return GetHash(key) % num_shards_;
132 }
133 }
134
135 20 void DistributedBRPCParameterClient::PartitionKeys(
136 const base::ConstArray<uint64_t>& keys,
137 std::vector<std::vector<uint64_t>>& partitioned_keys) const {
138
2/2
✓ Branch 5 taken 40 times.
✓ Branch 6 taken 20 times.
60 for (auto& partition : partitioned_key_buffer_) {
139 40 partition.clear();
140 }
141
2/2
✓ Branch 5 taken 40 times.
✓ Branch 6 taken 20 times.
60 for (auto& mapping : key_index_mapping_) {
142 40 mapping.clear();
143 }
144
145
2/2
✓ Branch 1 taken 448 times.
✓ Branch 2 taken 20 times.
468 for (size_t i = 0; i < keys.Size(); ++i) {
146 448 uint64_t key = keys[i];
147
1/2
✓ Branch 1 taken 448 times.
✗ Branch 2 not taken.
448 int shard_id = GetShardId(key);
148
149
1/2
✓ Branch 2 taken 448 times.
✗ Branch 3 not taken.
448 partitioned_key_buffer_[shard_id].push_back(key);
150
1/2
✓ Branch 2 taken 448 times.
✗ Branch 3 not taken.
448 key_index_mapping_[shard_id].push_back(i);
151 }
152
153 20 partitioned_keys = partitioned_key_buffer_;
154 20 }
155
156 14 bool DistributedBRPCParameterClient::GetParameter(
157 const base::ConstArray<uint64_t>& keys,
158 std::vector<std::vector<float>>* values) {
159 #ifdef ENABLE_PERF_REPORT
160 auto start_time = std::chrono::high_resolution_clock::now();
161 #endif
162
163
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 14 times.
14 if (keys.Size() == 0) {
164 values->clear();
165 return true;
166 }
167
168
2/4
✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 14 times.
✗ Branch 6 not taken.
28 xmh::Timer timer("DistributedBRPCParameterClient::GetParameter");
169
170 14 std::vector<std::vector<uint64_t>> partitioned_keys;
171
1/2
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
14 PartitionKeys(keys, partitioned_keys);
172
173 14 std::vector<std::future<int>> futures;
174
1/2
✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
14 std::vector<std::vector<std::vector<float>>> partitioned_results(num_shards_);
175
176
2/2
✓ Branch 0 taken 28 times.
✓ Branch 1 taken 14 times.
42 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
177
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 28 times.
28 if (partitioned_keys[shard_id].empty()) {
178 continue;
179 }
180
181
1/2
✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
28 auto it = shard_to_client_index_.find(shard_id);
182
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 28 times.
28 if (it == shard_to_client_index_.end()) {
183 LOG(ERROR) << "No client found for shard " << shard_id;
184 return false;
185 }
186
187 28 int client_index = it->second;
188 28 auto* client = clients_[client_index].get();
189
190 // 异步请求
191
2/4
✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
28 futures.push_back(std::async(
192 56 std::launch::async, [=, &partitioned_keys, &partitioned_results]() {
193 28 const auto& shard_keys_vec = partitioned_keys[shard_id];
194 28 auto& shard_result_vec = partitioned_results[shard_id];
195 28 shard_result_vec.clear();
196 28 shard_result_vec.reserve(shard_keys_vec.size());
197
198
2/2
✓ Branch 1 taken 30 times.
✓ Branch 2 taken 28 times.
58 for (size_t start = 0; start < shard_keys_vec.size();
199 30 start += static_cast<size_t>(max_keys_per_request_)) {
200 size_t end =
201 60 std::min(start + static_cast<size_t>(max_keys_per_request_),
202 30 shard_keys_vec.size());
203 base::ConstArray<uint64_t> shard_chunk(
204 30 shard_keys_vec.data() + start, static_cast<int>(end - start));
205 30 std::vector<std::vector<float>> chunk_result;
206
2/4
✓ Branch 1 taken 30 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 30 times.
30 if (!client->GetParameter(shard_chunk, &chunk_result)) {
207 return 0;
208 }
209
1/2
✓ Branch 5 taken 30 times.
✗ Branch 6 not taken.
30 shard_result_vec.insert(shard_result_vec.end(),
210 chunk_result.begin(),
211 chunk_result.end());
212
1/2
✓ Branch 1 taken 30 times.
✗ Branch 2 not taken.
30 }
213 28 return 1;
214 }));
215 }
216
217 // 等待所有请求完成
218
2/2
✓ Branch 5 taken 28 times.
✓ Branch 6 taken 14 times.
42 for (auto& future : futures) {
219
2/4
✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 28 times.
28 if (!future.get()) {
220 LOG(ERROR) << "Failed to get parameters from one of the shards";
221 return false;
222 }
223 }
224
225 // 合并结果
226
1/2
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
14 MergeResults(keys, partitioned_results, values);
227
228 #ifdef ENABLE_PERF_REPORT
229 auto end_time = std::chrono::high_resolution_clock::now();
230 auto duration =
231 std::chrono::duration_cast<std::chrono::microseconds>(
232 end_time - start_time)
233 .count();
234 double start_us =
235 std::chrono::duration_cast<std::chrono::microseconds>(
236 start_time.time_since_epoch())
237 .count();
238 FlameGraphData fg_data = {
239 "dist_client::GetParameter",
240 start_us,
241 1, // level
242 static_cast<double>(duration),
243 static_cast<double>(duration)};
244 std::string unique_id =
245 "embread_debug" + std::to_string(recstore::g_trace_id);
246 report_flame_graph("emb_read_flame_map", unique_id.c_str(), fg_data);
247 #endif
248
249 14 return true;
250 14 }
251
252 14 void DistributedBRPCParameterClient::MergeResults(
253 const base::ConstArray<uint64_t>& keys,
254 const std::vector<std::vector<std::vector<float>>>& partitioned_results,
255 std::vector<std::vector<float>>* values) const {
256 14 values->clear();
257 14 values->resize(keys.Size());
258
259 // 重建 key -> index 映射
260
2/2
✓ Branch 0 taken 28 times.
✓ Branch 1 taken 14 times.
42 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
261
2/2
✓ Branch 2 taken 236 times.
✓ Branch 3 taken 28 times.
264 for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) {
262 236 size_t original_index = key_index_mapping_[shard_id][i];
263
1/2
✓ Branch 2 taken 236 times.
✗ Branch 3 not taken.
236 if (i < partitioned_results[shard_id].size()) {
264 236 (*values)[original_index] = partitioned_results[shard_id][i];
265 }
266 }
267 }
268 14 }
269
270 void DistributedBRPCParameterClient::MergeResultsToArray(
271 const base::ConstArray<uint64_t>& keys,
272 const std::vector<std::vector<std::vector<float>>>& partitioned_results,
273 float* values) const {
274 int emb_dim = 0;
275 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
276 if (!partitioned_results[shard_id].empty() &&
277 !partitioned_results[shard_id][0].empty()) {
278 emb_dim = partitioned_results[shard_id][0].size();
279 break;
280 }
281 }
282
283 if (emb_dim == 0) {
284 LOG(WARNING) << "No valid embeddings found";
285 return;
286 }
287
288 // 合并结果到连续内存
289 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
290 for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) {
291 size_t original_index = key_index_mapping_[shard_id][i];
292 if (i < partitioned_results[shard_id].size()) {
293 const auto& embedding = partitioned_results[shard_id][i];
294 std::copy(embedding.begin(),
295 embedding.end(),
296 values + original_index * emb_dim);
297 }
298 }
299 }
300 }
301
302 // 实现 BasePSClient 接口
303 int DistributedBRPCParameterClient::GetParameter(
304 const base::ConstArray<uint64_t>& keys, float* values) {
305 std::vector<std::vector<float>> result_vectors;
306 bool success = GetParameter(keys, &result_vectors);
307
308 if (!success) {
309 return -1;
310 }
311
312 if (keys.Size() == 0) {
313 return 0;
314 }
315 int emb_dim = 0;
316 for (const auto& row : result_vectors) {
317 if (!row.empty()) {
318 emb_dim = static_cast<int>(row.size());
319 break;
320 }
321 }
322 if (emb_dim == 0) {
323 LOG(WARNING) << "No valid embeddings found";
324 return 0;
325 }
326
327 for (size_t i = 0; i < result_vectors.size(); ++i) {
328 const auto& row = result_vectors[i];
329 if (row.empty()) {
330 continue;
331 }
332 std::copy(row.begin(), row.end(), values + i * emb_dim);
333 }
334 return 0;
335 }
336
337 int DistributedBRPCParameterClient::AsyncGetParameter(
338 const base::ConstArray<uint64_t>& keys, float* values) {
339 return GetParameter(keys, values);
340 }
341
342 6 int DistributedBRPCParameterClient::PutParameter(
343 const base::ConstArray<uint64_t>& keys,
344 const std::vector<std::vector<float>>& values) {
345
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
6 if (keys.Size() != values.size()) {
346 LOG(ERROR) << "Keys and values size mismatch: " << keys.Size() << " vs "
347 << values.size();
348 return -1;
349 }
350
351 6 std::vector<std::vector<uint64_t>> partitioned_keys;
352
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 PartitionKeys(keys, partitioned_keys);
353
354
1/2
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 std::vector<std::vector<std::vector<float>>> partitioned_values(num_shards_);
355
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
18 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
356
2/2
✓ Branch 2 taken 212 times.
✓ Branch 3 taken 12 times.
224 for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) {
357 212 size_t original_index = key_index_mapping_[shard_id][i];
358
1/2
✓ Branch 3 taken 212 times.
✗ Branch 4 not taken.
212 partitioned_values[shard_id].push_back(values[original_index]);
359 }
360 }
361
362 // 并发 put 到各个分片
363 6 std::vector<std::future<int>> futures;
364
365
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
18 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
366
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
12 if (partitioned_keys[shard_id].empty()) {
367 continue;
368 }
369
370
1/2
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
12 auto it = shard_to_client_index_.find(shard_id);
371
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
12 if (it == shard_to_client_index_.end()) {
372 LOG(ERROR) << "No client found for shard " << shard_id;
373 return -1;
374 }
375
376 12 int client_index = it->second;
377 12 auto* client = clients_[client_index].get();
378
379
2/4
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
12 futures.push_back(std::async(
380 24 std::launch::async, [=, &partitioned_keys, &partitioned_values]() {
381 12 const auto& shard_keys_vec = partitioned_keys[shard_id];
382 12 const auto& shard_vals_vec = partitioned_values[shard_id];
383
2/2
✓ Branch 1 taken 14 times.
✓ Branch 2 taken 12 times.
26 for (size_t start = 0; start < shard_keys_vec.size();
384 14 start += static_cast<size_t>(max_keys_per_request_)) {
385 size_t end =
386 28 std::min(start + static_cast<size_t>(max_keys_per_request_),
387 14 shard_keys_vec.size());
388 base::ConstArray<uint64_t> shard_chunk(
389 14 shard_keys_vec.data() + start, static_cast<int>(end - start));
390 std::vector<std::vector<float>> value_chunk(
391
1/2
✓ Branch 6 taken 14 times.
✗ Branch 7 not taken.
14 shard_vals_vec.begin() + start, shard_vals_vec.begin() + end);
392
2/4
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 if (client->PutParameter(shard_chunk, value_chunk) != 1) {
393 return 0;
394 }
395
1/2
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
14 }
396 12 return 1;
397 }));
398 }
399
400 // 等待所有请求完成
401
2/2
✓ Branch 5 taken 12 times.
✓ Branch 6 taken 6 times.
18 for (auto& future : futures) {
402
2/4
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
12 if (future.get() != 1) {
403 LOG(ERROR) << "Failed to put parameters to one of the shards";
404 return -1;
405 }
406 }
407
408 6 return 0;
409 6 }
410
411 void DistributedBRPCParameterClient::Command(PSCommand command) {
412 std::vector<std::future<void>> futures;
413
414 for (auto& client : clients_) {
415 futures.push_back(std::async(std::launch::async, [&client, command]() {
416 client->Command(command);
417 }));
418 }
419
420 for (auto& future : futures) {
421 future.wait();
422 }
423 }
424
425 10 bool DistributedBRPCParameterClient::ClearPS() {
426 10 std::vector<std::future<bool>> futures;
427
428
2/2
✓ Branch 4 taken 20 times.
✓ Branch 5 taken 10 times.
30 for (auto& client : clients_) {
429
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
20 futures.push_back(std::async(std::launch::async, [&client]() {
430 20 return client->ClearPS();
431 }));
432 }
433
434 10 bool all_success = true;
435
2/2
✓ Branch 5 taken 20 times.
✓ Branch 6 taken 10 times.
30 for (auto& future : futures) {
436
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
20 if (!future.get()) {
437 all_success = false;
438 }
439 }
440
441 10 return all_success;
442 10 }
443
444 2 bool DistributedBRPCParameterClient::LoadFakeData(int64_t n) {
445 2 std::vector<std::future<bool>> futures;
446
2/2
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
6 for (auto& client : clients_) {
447 4 BRPCParameterClient* raw = client.get();
448
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
4 futures.push_back(std::async(std::launch::async, [raw, n]() {
449 4 return raw->LoadFakeData(n);
450 }));
451 }
452 2 bool all_success = true;
453
2/2
✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
6 for (auto& future : futures) {
454
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
4 if (!future.get()) {
455 all_success = false;
456 }
457 }
458 2 return all_success;
459 2 }
460
461 2 bool DistributedBRPCParameterClient::DumpFakeData(int64_t n) {
462 2 std::vector<std::future<bool>> futures;
463
2/2
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
6 for (auto& client : clients_) {
464 4 BRPCParameterClient* raw = client.get();
465
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
4 futures.push_back(std::async(std::launch::async, [raw, n]() {
466 4 return raw->DumpFakeData(n);
467 }));
468 }
469 2 bool all_success = true;
470
2/2
✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
6 for (auto& future : futures) {
471
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
4 if (!future.get()) {
472 all_success = false;
473 }
474 }
475 2 return all_success;
476 2 }
477 bool DistributedBRPCParameterClient::LoadCkpt(
478 const std::vector<std::string>& model_config_path,
479 const std::vector<std::string>& emb_file_path) {
480 std::vector<std::future<bool>> futures;
481
482 for (auto& client : clients_) {
483 futures.push_back(std::async(
484 std::launch::async, [&client, &model_config_path, &emb_file_path]() {
485 return client->LoadCkpt(model_config_path, emb_file_path);
486 }));
487 }
488
489 bool all_success = true;
490 for (auto& future : futures) {
491 if (!future.get()) {
492 all_success = false;
493 }
494 }
495
496 return all_success;
497 }
498
499 int DistributedBRPCParameterClient::UpdateParameter(
500 const std::string& table_name,
501 const base::ConstArray<uint64_t>& keys,
502 const std::vector<std::vector<float>>* grads) {
503 if (grads == nullptr) {
504 LOG(ERROR) << "UpdateParameter grads pointer is null";
505 return -1;
506 }
507 if (keys.Size() != grads->size()) {
508 LOG(ERROR) << "UpdateParameter keys/grads size mismatch: " << keys.Size()
509 << " vs " << grads->size();
510 return -1;
511 }
512 if (keys.Size() == 0) {
513 return 0;
514 }
515
516 std::vector<std::vector<uint64_t>> partitioned_keys;
517 PartitionKeys(keys, partitioned_keys);
518
519 std::vector<std::vector<std::vector<float>>> partitioned_grads(num_shards_);
520 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
521 for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) {
522 size_t original_index = key_index_mapping_[shard_id][i];
523 partitioned_grads[shard_id].push_back((*grads)[original_index]);
524 }
525 }
526
527 std::vector<std::future<int>> futures;
528 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
529 if (partitioned_keys[shard_id].empty()) {
530 continue;
531 }
532
533 auto it = shard_to_client_index_.find(shard_id);
534 if (it == shard_to_client_index_.end()) {
535 LOG(ERROR) << "No client found for shard " << shard_id;
536 return -1;
537 }
538 int client_index = it->second;
539 auto* client = clients_[client_index].get();
540
541 futures.push_back(std::async(
542 std::launch::async, [=, &partitioned_keys, &partitioned_grads]() {
543 const auto& shard_keys_vec = partitioned_keys[shard_id];
544 const auto& shard_grads_vec = partitioned_grads[shard_id];
545 for (size_t start = 0; start < shard_keys_vec.size();
546 start += static_cast<size_t>(max_keys_per_request_)) {
547 size_t end =
548 std::min(start + static_cast<size_t>(max_keys_per_request_),
549 shard_keys_vec.size());
550 base::ConstArray<uint64_t> shard_chunk(
551 shard_keys_vec.data() + start, static_cast<int>(end - start));
552 std::vector<std::vector<float>> grad_chunk(
553 shard_grads_vec.begin() + start, shard_grads_vec.begin() + end);
554 if (client->UpdateParameter(table_name, shard_chunk, &grad_chunk) !=
555 0) {
556 return -1;
557 }
558 }
559 return 0;
560 }));
561 }
562
563 for (auto& future : futures) {
564 if (future.get() != 0) {
565 LOG(ERROR) << "Failed to update parameters on one of the shards";
566 return -1;
567 }
568 }
569
570 return 0;
571 }
572
573 int DistributedBRPCParameterClient::UpdateParameterFlat(
574 const std::string& table_name,
575 const base::ConstArray<uint64_t>& keys,
576 const float* grads,
577 int64_t num_rows,
578 int64_t embedding_dim) {
579 if (grads == nullptr) {
580 LOG(ERROR) << "UpdateParameterFlat grads pointer is null";
581 return -1;
582 }
583 if (num_rows < 0 || embedding_dim <= 0) {
584 LOG(ERROR) << "UpdateParameterFlat invalid shape: rows=" << num_rows
585 << " dim=" << embedding_dim;
586 return -1;
587 }
588 if (keys.Size() != static_cast<size_t>(num_rows)) {
589 LOG(ERROR) << "UpdateParameterFlat keys/grads size mismatch: "
590 << keys.Size() << " vs " << num_rows;
591 return -1;
592 }
593
594 std::vector<std::vector<float>> row_grads;
595 row_grads.reserve(static_cast<size_t>(num_rows));
596 for (int64_t i = 0; i < num_rows; ++i) {
597 const float* row = grads + i * embedding_dim;
598 row_grads.emplace_back(row, row + embedding_dim);
599 }
600 return UpdateParameter(table_name, keys, &row_grads);
601 }
602
603 int DistributedBRPCParameterClient::InitEmbeddingTable(
604 const std::string& table_name,
605 const recstore::EmbeddingTableConfig& config) {
606 std::vector<std::future<int>> futures;
607 for (auto& client : clients_) {
608 futures.push_back(
609 std::async(std::launch::async, [&client, &table_name, &config]() {
610 return client->InitEmbeddingTable(table_name, config);
611 }));
612 }
613
614 for (auto& future : futures) {
615 if (future.get() != 0) {
616 LOG(ERROR) << "InitEmbeddingTable failed on one of the shards";
617 return -1;
618 }
619 }
620 return 0;
621 }
622
623 // Prefetch 接口实现
624 uint64_t DistributedBRPCParameterClient::PrefetchParameter(
625 const base::ConstArray<uint64_t>& keys) {
626 auto cleanup_state = [this](const DistPrefetchState& state) {
627 for (const auto& shard_state : state.shard_states) {
628 if (shard_state.client_index < 0 ||
629 shard_state.client_index >= static_cast<int>(clients_.size())) {
630 continue;
631 }
632 auto* client = clients_[shard_state.client_index].get();
633 for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) {
634 client->WaitForPrefetch(child_prefetch_id);
635 std::vector<std::vector<float>> tmp;
636 client->GetPrefetchResult(child_prefetch_id, &tmp);
637 }
638 }
639 };
640
641 if (keys.Size() == 0) {
642 std::lock_guard<std::mutex> lk(prefetch_mu_);
643 uint64_t prefetch_id = next_prefetch_id_++;
644 auto state = std::make_shared<DistPrefetchState>();
645 state->total_keys = 0;
646 prefetch_states_[prefetch_id] = state;
647 return prefetch_id;
648 }
649
650 std::vector<std::vector<uint64_t>> shard_keys(num_shards_);
651 std::vector<std::vector<size_t>> shard_indices(num_shards_);
652 for (size_t i = 0; i < keys.Size(); ++i) {
653 const int shard_id = GetShardId(keys[i]);
654 shard_keys[shard_id].push_back(keys[i]);
655 shard_indices[shard_id].push_back(i);
656 }
657
658 auto state = std::make_shared<DistPrefetchState>();
659 state->total_keys = keys.Size();
660
661 for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
662 if (shard_keys[shard_id].empty()) {
663 continue;
664 }
665
666 auto it = shard_to_client_index_.find(shard_id);
667 if (it == shard_to_client_index_.end()) {
668 LOG(ERROR) << "No client found for shard " << shard_id;
669 cleanup_state(*state);
670 return 0;
671 }
672
673 DistPrefetchShardState shard_state;
674 shard_state.shard_id = shard_id;
675 shard_state.client_index = it->second;
676 shard_state.original_indices = std::move(shard_indices[shard_id]);
677
678 const auto& skeys = shard_keys[shard_id];
679 for (size_t start = 0; start < skeys.size();
680 start += static_cast<size_t>(max_keys_per_request_)) {
681 size_t end = std::min(
682 start + static_cast<size_t>(max_keys_per_request_), skeys.size());
683 base::ConstArray<uint64_t> chunk(
684 skeys.data() + start, static_cast<int>(end - start));
685 uint64_t child_prefetch_id =
686 clients_[shard_state.client_index]->PrefetchParameter(chunk);
687 if (child_prefetch_id == 0) {
688 LOG(ERROR) << "PrefetchParameter failed for shard " << shard_id;
689 cleanup_state(*state);
690 return 0;
691 }
692 shard_state.child_prefetch_ids.push_back(child_prefetch_id);
693 shard_state.chunk_sizes.push_back(static_cast<int>(end - start));
694 }
695 state->shard_states.push_back(std::move(shard_state));
696 }
697
698 std::lock_guard<std::mutex> lk(prefetch_mu_);
699 uint64_t prefetch_id = next_prefetch_id_++;
700 prefetch_states_[prefetch_id] = std::move(state);
701 return prefetch_id;
702 }
703
704 bool DistributedBRPCParameterClient::IsPrefetchDone(uint64_t prefetch_id) {
705 std::shared_ptr<DistPrefetchState> state;
706 {
707 std::lock_guard<std::mutex> lk(prefetch_mu_);
708 auto it = prefetch_states_.find(prefetch_id);
709 if (it == prefetch_states_.end()) {
710 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
711 return false;
712 }
713 state = it->second;
714 }
715
716 for (const auto& shard_state : state->shard_states) {
717 auto* client = clients_[shard_state.client_index].get();
718 for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) {
719 if (!client->IsPrefetchDone(child_prefetch_id)) {
720 return false;
721 }
722 }
723 }
724 return true;
725 }
726
727 void DistributedBRPCParameterClient::WaitForPrefetch(uint64_t prefetch_id) {
728 std::shared_ptr<DistPrefetchState> state;
729 {
730 std::lock_guard<std::mutex> lk(prefetch_mu_);
731 auto it = prefetch_states_.find(prefetch_id);
732 if (it == prefetch_states_.end()) {
733 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
734 return;
735 }
736 state = it->second;
737 }
738
739 for (const auto& shard_state : state->shard_states) {
740 auto* client = clients_[shard_state.client_index].get();
741 for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) {
742 client->WaitForPrefetch(child_prefetch_id);
743 }
744 }
745 }
746
747 bool DistributedBRPCParameterClient::GetPrefetchResult(
748 uint64_t prefetch_id, std::vector<std::vector<float>>* values) {
749 if (values == nullptr) {
750 LOG(ERROR) << "GetPrefetchResult output pointer is null";
751 return false;
752 }
753
754 std::shared_ptr<DistPrefetchState> state;
755 {
756 std::lock_guard<std::mutex> lk(prefetch_mu_);
757 auto it = prefetch_states_.find(prefetch_id);
758 if (it == prefetch_states_.end()) {
759 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
760 return false;
761 }
762 state = it->second;
763 }
764
765 // Ensure all child RPCs are completed before consuming payloads.
766 WaitForPrefetch(prefetch_id);
767
768 values->clear();
769 values->resize(state->total_keys);
770
771 bool ok_all = true;
772 for (const auto& shard_state : state->shard_states) {
773 auto* client = clients_[shard_state.client_index].get();
774 size_t shard_offset = 0;
775 for (size_t i = 0; i < shard_state.child_prefetch_ids.size(); ++i) {
776 std::vector<std::vector<float>> chunk_values;
777 if (!client->GetPrefetchResult(
778 shard_state.child_prefetch_ids[i], &chunk_values)) {
779 ok_all = false;
780 break;
781 }
782 const int expected =
783 (i < shard_state.chunk_sizes.size()
784 ? shard_state.chunk_sizes[i]
785 : -1);
786 if (expected >= 0 && static_cast<int>(chunk_values.size()) != expected) {
787 LOG(ERROR) << "Prefetch chunk size mismatch: got "
788 << chunk_values.size() << ", expected " << expected;
789 ok_all = false;
790 break;
791 }
792 for (const auto& row : chunk_values) {
793 if (shard_offset >= shard_state.original_indices.size()) {
794 LOG(ERROR) << "Prefetch result overflow in shard "
795 << shard_state.shard_id;
796 ok_all = false;
797 break;
798 }
799 (*values)[shard_state.original_indices[shard_offset++]] = row;
800 }
801 if (!ok_all) {
802 break;
803 }
804 }
805 if (!ok_all) {
806 break;
807 }
808 }
809
810 {
811 std::lock_guard<std::mutex> lk(prefetch_mu_);
812 prefetch_states_.erase(prefetch_id);
813 }
814 return ok_all;
815 }
816
817 bool DistributedBRPCParameterClient::GetPrefetchResultFlat(
818 uint64_t prefetch_id,
819 std::vector<float>* values,
820 int64_t* num_rows,
821 int64_t embedding_dim) {
822 if (values == nullptr || num_rows == nullptr) {
823 LOG(ERROR) << "GetPrefetchResultFlat output pointer is null";
824 return false;
825 }
826 if (embedding_dim <= 0) {
827 LOG(ERROR) << "GetPrefetchResultFlat invalid embedding_dim: "
828 << embedding_dim;
829 return false;
830 }
831
832 std::vector<std::vector<float>> merged_values;
833 if (!GetPrefetchResult(prefetch_id, &merged_values)) {
834 return false;
835 }
836
837 *num_rows = static_cast<int64_t>(merged_values.size());
838 values->assign(
839 static_cast<size_t>(*num_rows) * static_cast<size_t>(embedding_dim),
840 0.0f);
841 for (size_t i = 0; i < merged_values.size(); ++i) {
842 const auto& row = merged_values[i];
843 if (row.empty()) {
844 continue;
845 }
846 const int64_t copy_d =
847 std::min<int64_t>(embedding_dim, static_cast<int64_t>(row.size()));
848 std::memcpy(values->data() + i * static_cast<size_t>(embedding_dim),
849 row.data(),
850 static_cast<size_t>(copy_d) * sizeof(float));
851 }
852 return true;
853 }
854
855 } // namespace recstore
856