GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 46.7% 254 / 0 / 544
Functions: 60.0% 18 / 0 / 30
Branches: 19.5% 210 / 0 / 1079

ps/brpc/brpc_ps_client.cpp
Line Branch Exec Source
1 #include "brpc_ps_client.h"
2
3 #include <brpc/channel.h>
4 #include <fmt/core.h>
5
6 #include <cstdint>
7 #include <cstring>
8 #include <future>
9 #include <string>
10 #include <vector>
11
12 #include "base/array.h"
13 #include "base/factory.h"
14 #include "base/flatc.h"
15 #include "base/log.h"
16 #include "base/timer.h"
17 #include "ps/base/parameters.h"
18 #include "ps_brpc.pb.h"
19
20 #ifdef ENABLE_PERF_REPORT
21 # include <chrono>
22 # include "base/report/report_client.h"
23 #endif
24
25 using recstoreps_brpc::CommandRequest;
26 using recstoreps_brpc::CommandResponse;
27 using recstoreps_brpc::GetParameterRequest;
28 using recstoreps_brpc::GetParameterResponse;
29 using recstoreps_brpc::InitEmbeddingTableRequest;
30 using recstoreps_brpc::InitEmbeddingTableResponse;
31 using recstoreps_brpc::PSCommand;
32 using recstoreps_brpc::PutParameterRequest;
33 using recstoreps_brpc::PutParameterResponse;
34 using recstoreps_brpc::UpdateParameterRequest;
35 using recstoreps_brpc::UpdateParameterResponse;
36
37 namespace {
38
39 50 const ParameterCompressReader* ExtractGetResponseReader(
40 const brpc::Controller& cntl,
41 const GetParameterResponse& response,
42 std::string* payload_storage,
43 int* payload_size) {
44
1/2
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
50 if (!cntl.response_attachment().empty()) {
45 50 payload_storage->clear();
46 50 cntl.response_attachment().copy_to(payload_storage);
47 50 *payload_size = payload_storage->size();
48 return reinterpret_cast<const ParameterCompressReader*>(
49 50 payload_storage->data());
50 }
51
52 *payload_size = response.parameter_value().size();
53 return reinterpret_cast<const ParameterCompressReader*>(
54 response.parameter_value().data());
55 }
56
57 } // namespace
58
59 namespace {
60
61 int BuildUpdateBlocksFromFlat(
62 const base::ConstArray<uint64_t>& keys,
63 const float* grads,
64 int64_t num_rows,
65 int64_t embedding_dim,
66 ParameterCompressor* compressor) {
67 if (grads == nullptr) {
68 LOG(ERROR) << "UpdateParameterFlat grads pointer is null";
69 return -1;
70 }
71 if (num_rows < 0 || embedding_dim <= 0) {
72 LOG(ERROR) << "UpdateParameterFlat invalid shape: rows=" << num_rows
73 << " dim=" << embedding_dim;
74 return -1;
75 }
76 if (keys.Size() != static_cast<size_t>(num_rows)) {
77 LOG(ERROR) << "UpdateParameterFlat keys/grads size mismatch: "
78 << keys.Size() << " vs " << num_rows;
79 return -1;
80 }
81
82 for (int64_t i = 0; i < num_rows; ++i) {
83 ParameterPack pack;
84 pack.key = keys[static_cast<size_t>(i)];
85 pack.dim = embedding_dim;
86 pack.emb_data = grads + i * embedding_dim;
87 compressor->AddItem(pack, nullptr);
88 }
89 return 0;
90 }
91
92 } // namespace
93
94 DEFINE_int32(brpc_timeout_ms, 5000, "brpc request timeout in milliseconds");
95 DEFINE_int32(brpc_max_retry, 3, "brpc max retry times");
96 DEFINE_bool(parameter_client_random_init_brpc, false, "");
97
98 // New constructor that takes JSON config
99 24 BRPCParameterClient::BRPCParameterClient(json config)
100
1/2
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
24 : recstore::BasePSClient(config) {
101
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 host_ = config.value("host", "localhost");
102
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 port_ = config.value("port", 15000);
103
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 shard_ = config.value("shard", 0);
104
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 timeout_ms_ = config.value("timeout_ms", FLAGS_brpc_timeout_ms);
105
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 max_retry_ = config.value("max_retry", FLAGS_brpc_max_retry);
106
107
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 Initialize();
108
109 // Initialize bRPC channel
110
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 channel_ = std::make_shared<brpc::Channel>();
111
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 brpc::ChannelOptions options;
112 24 options.timeout_ms = timeout_ms_;
113 24 options.max_retry = max_retry_;
114
115
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 std::string server_addr = fmt::format("{}:{}", host_, port_);
116
2/4
✓ Branch 3 taken 24 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 24 times.
24 if (channel_->Init(server_addr.c_str(), &options) != 0) {
117 LOG(ERROR) << "Failed to initialize bRPC channel to " << server_addr;
118 } else {
119
3/6
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
48 LOG(INFO) << "Initialized bRPC PS Client Shard " << shard_ << " at "
120
3/6
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
24 << server_addr;
121 }
122 24 }
123
124 // Legacy constructor for backward compatibility
125 6 BRPCParameterClient::BRPCParameterClient(
126 6 const std::string& host, int port, int shard)
127 : recstore::BasePSClient(
128
14/28
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 10 not taken.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 20 not taken.
✓ Branch 22 taken 18 times.
✓ Branch 23 taken 6 times.
✓ Branch 25 taken 12 times.
✓ Branch 26 taken 6 times.
✓ Branch 28 taken 12 times.
✓ Branch 29 taken 6 times.
✓ Branch 31 taken 12 times.
✓ Branch 32 taken 6 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
66 json{{"host", host}, {"port", port}, {"shard", shard}}),
129 6 host_(host),
130 6 port_(port),
131 6 shard_(shard),
132 6 timeout_ms_(FLAGS_brpc_timeout_ms),
133
2/4
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 6 times.
✗ Branch 6 not taken.
18 max_retry_(FLAGS_brpc_max_retry) {
134
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 Initialize();
135
136
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 channel_ = std::make_shared<brpc::Channel>();
137
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 brpc::ChannelOptions options;
138 6 options.timeout_ms = timeout_ms_;
139
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 options.max_retry = max_retry_;
140
141 std::string server_addr = fmt::format("{}:{}", host, port);
142
2/4
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
6 if (channel_->Init(server_addr.c_str(), &options) != 0) {
143 LOG(ERROR) << "Failed to initialize bRPC channel to " << server_addr;
144 } else {
145
3/6
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 8 not taken.
12 LOG(INFO) << "Initialized bRPC PS Client Shard " << shard_ << " at "
146
3/6
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 8 not taken.
6 << server_addr;
147 }
148 6 }
149
150 30 bool BRPCParameterClient::Initialize() { return true; }
151
152 int BRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys,
153 float* values) {
154 #ifdef ENABLE_PERF_REPORT
155 auto start_time = std::chrono::high_resolution_clock::now();
156 #endif
157
158 if (FLAGS_parameter_client_random_init_brpc) {
159 CHECK(0) << "todo implement";
160 return true;
161 }
162
163 int request_num =
164 (keys.Size() + MAX_PARAMETER_BATCH_BRPC - 1) / MAX_PARAMETER_BATCH_BRPC;
165 std::vector<GetParameterRequest> requests(request_num);
166 std::vector<GetParameterResponse> responses(request_num);
167 std::vector<brpc::Controller> controllers(request_num);
168 std::vector<int> key_sizes;
169
170 // Create stub
171 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
172
173 #ifdef ENABLE_PERF_REPORT
174 auto wait_start_time = std::chrono::high_resolution_clock::now();
175 #endif
176
177 // Send async RPC requests
178 for (int start = 0, index = 0; start < keys.Size();
179 start += MAX_PARAMETER_BATCH_BRPC, ++index) {
180 int key_size =
181 std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH_BRPC);
182 key_sizes.push_back(key_size);
183
184 controllers[index].request_attachment().append(
185 reinterpret_cast<const char*>(&keys[start]),
186 sizeof(uint64_t) * key_size);
187
188 google::protobuf::Closure* done = brpc::NewCallback([]() { /* no-op */ });
189 stub.GetParameter(
190 &controllers[index], &requests[index], &responses[index], done);
191 }
192
193 // Wait for all RPCs to complete
194 for (int i = 0; i < request_num; ++i) {
195 brpc::Join(controllers[i].call_id());
196 if (controllers[i].Failed()) {
197 LOG(ERROR) << "bRPC GetParameter failed: " << controllers[i].ErrorText();
198 return false;
199 }
200 }
201
202 #ifdef ENABLE_PERF_REPORT
203 auto wait_end_time = std::chrono::high_resolution_clock::now();
204 auto wait_duration =
205 std::chrono::duration_cast<std::chrono::microseconds>(
206 wait_end_time - wait_start_time)
207 .count();
208 double wait_start_us =
209 std::chrono::duration_cast<std::chrono::microseconds>(
210 wait_start_time.time_since_epoch())
211 .count();
212 std::string wait_label =
213 "brpc_client::RPC_Call_And_Wait_Shard" + std::to_string(shard_);
214 FlameGraphData wait_fg = {
215 wait_label,
216 wait_start_us,
217 2, // level
218 static_cast<double>(wait_duration),
219 static_cast<double>(wait_duration)};
220 std::string unique_id =
221 "embread_debug|" + std::to_string(static_cast<uint64_t>(wait_start_us));
222 report_flame_graph("emb_read_flame_map", unique_id.c_str(), wait_fg);
223
224 double start_us_for_rpc =
225 std::chrono::duration_cast<std::chrono::microseconds>(
226 start_time.time_since_epoch())
227 .count();
228 std::string report_id_for_rpc =
229 "brpc_client::GetParameter|" +
230 std::to_string(static_cast<uint64_t>(start_us_for_rpc));
231 report("embread_stages",
232 report_id_for_rpc.c_str(),
233 "rpc_duration_us",
234 static_cast<double>(wait_duration));
235
236 auto deserialize_start_time = std::chrono::high_resolution_clock::now();
237 #endif
238
239 // Parse responses
240 size_t get_embedding_acc = 0;
241 int old_dimension = -1;
242 std::string payload_storage;
243
244 for (int i = 0; i < responses.size(); ++i) {
245 auto& response = responses[i];
246 int key_size = key_sizes[i];
247 int payload_size = 0;
248 auto parameters = ExtractGetResponseReader(
249 controllers[i], response, &payload_storage, &payload_size);
250
251 if (parameters == nullptr || !parameters->Valid(payload_size)) {
252 LOG(ERROR) << "GetParameter invalid payload: " << payload_size;
253 return false;
254 }
255
256 if (parameters->size != key_size) {
257 LOG(ERROR) << "GetParameter error: " << parameters->size << " vs "
258 << key_size;
259 return false;
260 }
261
262 for (int index = 0; index < parameters->item_size(); ++index) {
263 auto item = parameters->item(index);
264 if (item->dim != 0) {
265 if (old_dimension == -1)
266 old_dimension = item->dim;
267 CHECK_EQ(item->dim, old_dimension);
268 std::copy_n(
269 item->embedding, item->dim, values + item->dim * get_embedding_acc);
270 } else {
271 RECSTORE_LOG_EVERY_MS(ERROR, 2000)
272 << "error; not find key " << keys[get_embedding_acc] << " in ps";
273 }
274 get_embedding_acc++;
275 }
276 }
277
278 #ifdef ENABLE_PERF_REPORT
279 auto deserialize_end_time = std::chrono::high_resolution_clock::now();
280 auto deserialize_duration =
281 std::chrono::duration_cast<std::chrono::microseconds>(
282 deserialize_end_time - deserialize_start_time)
283 .count();
284 double deserialize_start_us =
285 std::chrono::duration_cast<std::chrono::microseconds>(
286 deserialize_start_time.time_since_epoch())
287 .count();
288 std::string des_label =
289 "brpc_client::Deserialize_Shard" + std::to_string(shard_);
290 FlameGraphData des_fg = {
291 des_label,
292 deserialize_start_us,
293 2, // level
294 static_cast<double>(deserialize_duration),
295 static_cast<double>(deserialize_duration)};
296 std::string des_unique_id =
297 "embread_debug|" +
298 std::to_string(static_cast<uint64_t>(deserialize_start_us));
299 report_flame_graph("emb_read_flame_map", des_unique_id.c_str(), des_fg);
300
301 double start_us_for_des =
302 std::chrono::duration_cast<std::chrono::microseconds>(
303 start_time.time_since_epoch())
304 .count();
305 std::string report_id_for_des =
306 "brpc_client::GetParameter|" +
307 std::to_string(static_cast<uint64_t>(start_us_for_des));
308 report("embread_stages",
309 report_id_for_des.c_str(),
310 "deserialize_duration_us",
311 static_cast<double>(deserialize_duration));
312 #endif
313
314 #ifdef ENABLE_PERF_REPORT
315 auto end_time = std::chrono::high_resolution_clock::now();
316 auto duration =
317 std::chrono::duration_cast<std::chrono::microseconds>(
318 end_time - start_time)
319 .count();
320 report("ps_client_latency",
321 "GetParameter",
322 "latency_us",
323 static_cast<double>(duration));
324
325 double start_us =
326 std::chrono::duration_cast<std::chrono::microseconds>(
327 start_time.time_since_epoch())
328 .count();
329 FlameGraphData fg_data = {
330 "brpc_client::GetParameter",
331 start_us,
332 1, // level
333 static_cast<double>(duration),
334 static_cast<double>(duration)};
335
336 std::string report_id = "brpc_client::GetParameter|" +
337 std::to_string(static_cast<uint64_t>(start_us));
338
339 report("embread_stages",
340 report_id.c_str(),
341 "duration_us",
342 static_cast<double>(duration));
343
344 report("embread_stages",
345 report_id.c_str(),
346 "request_size",
347 static_cast<double>(keys.Size()));
348
349 std::string final_unique_id =
350 "embread_debug|" + std::to_string(static_cast<uint64_t>(start_us));
351 report_flame_graph("emb_read_flame_map", final_unique_id.c_str(), fg_data);
352 #endif
353
354 return true;
355 }
356
357 40 int BRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys,
358 std::vector<std::vector<float>>* values) {
359 #ifdef ENABLE_PERF_REPORT
360 auto start_time = std::chrono::high_resolution_clock::now();
361 #endif
362
363
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 40 times.
40 if (FLAGS_parameter_client_random_init_brpc) {
364 values->clear();
365 values->reserve(keys.Size());
366 for (size_t i = 0; i < keys.Size(); i++)
367 values->emplace_back(std::vector<float>(128, 0.1));
368 return true;
369 }
370
371 40 values->clear();
372
1/2
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
40 values->reserve(keys.Size());
373
374 int request_num =
375 40 (keys.Size() + MAX_PARAMETER_BATCH_BRPC - 1) / MAX_PARAMETER_BATCH_BRPC;
376
1/2
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
40 std::vector<GetParameterRequest> requests(request_num);
377
1/2
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
40 std::vector<GetParameterResponse> responses(request_num);
378
1/2
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
40 std::vector<brpc::Controller> controllers(request_num);
379 40 std::vector<int> key_sizes;
380
381
1/2
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
40 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
382
383 #ifdef ENABLE_PERF_REPORT
384 auto wait_start_time = std::chrono::high_resolution_clock::now();
385 #endif
386
387 // Send async RPC requests
388
2/2
✓ Branch 1 taken 40 times.
✓ Branch 2 taken 40 times.
80 for (int start = 0, index = 0; start < keys.Size();
389 40 start += MAX_PARAMETER_BATCH_BRPC, ++index) {
390 int key_size =
391 40 std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH_BRPC);
392
1/2
✓ Branch 1 taken 40 times.
✗ Branch 2 not taken.
40 key_sizes.push_back(key_size);
393
394
1/2
✓ Branch 3 taken 40 times.
✗ Branch 4 not taken.
80 controllers[index].request_attachment().append(
395 40 reinterpret_cast<const char*>(&keys[start]),
396 40 sizeof(uint64_t) * key_size);
397
398
1/2
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
40 google::protobuf::Closure* done = brpc::NewCallback([]() { /* no-op */ });
399
1/2
✓ Branch 1 taken 40 times.
✗ Branch 2 not taken.
40 stub.GetParameter(
400 40 &controllers[index], &requests[index], &responses[index], done);
401 }
402
403 // Wait for all RPCs to complete
404
2/2
✓ Branch 0 taken 40 times.
✓ Branch 1 taken 40 times.
80 for (int i = 0; i < request_num; ++i) {
405
2/4
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 40 times.
✗ Branch 6 not taken.
40 brpc::Join(controllers[i].call_id());
406
2/4
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 40 times.
40 if (controllers[i].Failed()) {
407 LOG(ERROR) << "bRPC GetParameter failed: " << controllers[i].ErrorText();
408 return false;
409 }
410 }
411
412 #ifdef ENABLE_PERF_REPORT
413 auto wait_end_time = std::chrono::high_resolution_clock::now();
414 auto wait_duration =
415 std::chrono::duration_cast<std::chrono::microseconds>(
416 wait_end_time - wait_start_time)
417 .count();
418 double wait_start_us =
419 std::chrono::duration_cast<std::chrono::microseconds>(
420 wait_start_time.time_since_epoch())
421 .count();
422 std::string wait_label =
423 "brpc_client::RPC_Call_And_Wait_Shard" + std::to_string(shard_);
424 FlameGraphData wait_fg = {
425 wait_label,
426 wait_start_us,
427 2, // level
428 static_cast<double>(wait_duration),
429 static_cast<double>(wait_duration)};
430 std::string unique_id =
431 "embread_debug|" + std::to_string(static_cast<uint64_t>(wait_start_us));
432 report_flame_graph("emb_read_flame_map", unique_id.c_str(), wait_fg);
433
434 double start_us_for_rpc =
435 std::chrono::duration_cast<std::chrono::microseconds>(
436 start_time.time_since_epoch())
437 .count();
438 std::string report_id_for_rpc =
439 "brpc_client::GetParameter_Vec|" +
440 std::to_string(static_cast<uint64_t>(start_us_for_rpc));
441
442 report("embread_stages",
443 report_id_for_rpc.c_str(),
444 "rpc_duration_us",
445 static_cast<double>(wait_duration));
446
447 auto deserialize_start_time = std::chrono::high_resolution_clock::now();
448 #endif
449
450 // Parse responses
451 40 std::string payload_storage;
452
2/2
✓ Branch 1 taken 40 times.
✓ Branch 2 taken 40 times.
80 for (int i = 0; i < responses.size(); ++i) {
453 40 auto& response = responses[i];
454 40 int key_size = key_sizes[i];
455 40 int payload_size = 0;
456
1/2
✓ Branch 1 taken 40 times.
✗ Branch 2 not taken.
40 auto parameters = ExtractGetResponseReader(
457 40 controllers[i], response, &payload_storage, &payload_size);
458
459
4/8
✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
✓ Branch 3 taken 40 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 40 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 40 times.
40 if (parameters == nullptr || !parameters->Valid(payload_size)) {
460 LOG(ERROR) << "GetParameter(vector) invalid payload: " << payload_size;
461 return false;
462 }
463
464
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 40 times.
40 if (unlikely(parameters->size != key_size)) {
465 LOG(ERROR) << "GetParameter error: " << parameters->size << " vs "
466 << key_size;
467 return false;
468 }
469
470
2/2
✓ Branch 1 taken 266 times.
✓ Branch 2 taken 40 times.
306 for (int index = 0; index < parameters->item_size(); ++index) {
471
1/2
✓ Branch 1 taken 266 times.
✗ Branch 2 not taken.
266 auto item = parameters->item(index);
472
2/2
✓ Branch 0 taken 254 times.
✓ Branch 1 taken 12 times.
266 if (item->dim != 0) {
473
1/2
✓ Branch 1 taken 254 times.
✗ Branch 2 not taken.
254 values->emplace_back(
474
1/2
✓ Branch 2 taken 254 times.
✗ Branch 3 not taken.
508 std::vector<float>(item->embedding, item->embedding + item->dim));
475 } else {
476
2/4
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 12 times.
✗ Branch 6 not taken.
12 values->emplace_back(std::vector<float>(0));
477 }
478 }
479 }
480
481 #ifdef ENABLE_PERF_REPORT
482 auto deserialize_end_time = std::chrono::high_resolution_clock::now();
483 auto deserialize_duration =
484 std::chrono::duration_cast<std::chrono::microseconds>(
485 deserialize_end_time - deserialize_start_time)
486 .count();
487 double deserialize_start_us =
488 std::chrono::duration_cast<std::chrono::microseconds>(
489 deserialize_start_time.time_since_epoch())
490 .count();
491 std::string des_label =
492 "brpc_client::Deserialize_Shard" + std::to_string(shard_);
493 FlameGraphData des_fg = {
494 des_label,
495 deserialize_start_us,
496 2, // level
497 static_cast<double>(deserialize_duration),
498 static_cast<double>(deserialize_duration)};
499 std::string des_unique_id =
500 "embread_debug|" +
501 std::to_string(static_cast<uint64_t>(deserialize_start_us));
502 report_flame_graph("emb_read_flame_map", des_unique_id.c_str(), des_fg);
503
504 double start_us_for_des =
505 std::chrono::duration_cast<std::chrono::microseconds>(
506 start_time.time_since_epoch())
507 .count();
508 std::string report_id_for_des =
509 "brpc_client::GetParameter_Vec|" +
510 std::to_string(static_cast<uint64_t>(start_us_for_des));
511 report("embread_stages",
512 report_id_for_des.c_str(),
513 "deserialize_duration_us",
514 static_cast<double>(deserialize_duration));
515 #endif
516
517 #ifdef ENABLE_PERF_REPORT
518 auto end_time = std::chrono::high_resolution_clock::now();
519 auto duration =
520 std::chrono::duration_cast<std::chrono::microseconds>(
521 end_time - start_time)
522 .count();
523 report("ps_client_latency",
524 "GetParameter",
525 "latency_us",
526 static_cast<double>(duration));
527
528 double start_us =
529 std::chrono::duration_cast<std::chrono::microseconds>(
530 start_time.time_since_epoch())
531 .count();
532 FlameGraphData fg_data = {
533 "brpc_client::GetParameter_Vec",
534 start_us,
535 1, // level
536 static_cast<double>(duration),
537 static_cast<double>(duration)};
538
539 std::string report_id = "brpc_client::GetParameter_Vec|" +
540 std::to_string(static_cast<uint64_t>(start_us));
541
542 report("embread_stages",
543 report_id.c_str(),
544 "duration_us",
545 static_cast<double>(duration));
546
547 report("embread_stages",
548 report_id.c_str(),
549 "request_size",
550 static_cast<double>(keys.Size()));
551
552 std::string final_unique_id =
553 "embread_debug|" + std::to_string(static_cast<uint64_t>(start_us));
554 report_flame_graph("emb_read_flame_map", final_unique_id.c_str(), fg_data);
555 #endif
556
557 40 return true;
558 40 }
559
560 10 static void OnPrefetchDone(BrpcPrefetchBatch* batch) {
561 10 batch->completed_count_++;
562 10 }
563
564 8 static void OnPrewriteDone(BrpcPrewriteBatch* batch) {
565 8 batch->completed_count_++;
566 8 }
567
568 uint64_t
569 10 BRPCParameterClient::PrefetchParameter(const base::ConstArray<uint64_t>& keys) {
570 10 uint64_t prefetch_id = next_prefetch_id_++;
571 int request_num =
572 10 (keys.Size() + MAX_PARAMETER_BATCH_BRPC - 1) / MAX_PARAMETER_BATCH_BRPC;
573
574 // Construct in map so batch pointers stay valid
575
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 auto it = prefetch_batches_.emplace(prefetch_id, request_num).first;
576 10 struct BrpcPrefetchBatch* pb = &it->second;
577
578
1/2
✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
10 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
579
580
2/2
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 10 times.
20 for (int start = 0, index = 0; start < keys.Size();
581 10 start += MAX_PARAMETER_BATCH_BRPC, ++index) {
582 int key_size =
583 10 std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH_BRPC);
584 10 pb->key_sizes_[index] = key_size;
585
586
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 GetParameterRequest request;
587
588
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 pb->controllers_[index] = std::make_unique<brpc::Controller>();
589
1/2
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
20 pb->controllers_[index]->request_attachment().append(
590 10 reinterpret_cast<const char*>(&keys[start]),
591 10 sizeof(uint64_t) * key_size);
592
593
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 google::protobuf::Closure* done = brpc::NewCallback(OnPrefetchDone, pb);
594
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 stub.GetParameter(
595 10 pb->controllers_[index].get(), &request, &pb->responses_[index], done);
596 10 }
597
598 10 return prefetch_id;
599 10 }
600
601 bool BRPCParameterClient::IsPrefetchDone(uint64_t prefetch_id) {
602 auto it = prefetch_batches_.find(prefetch_id);
603 if (it == prefetch_batches_.end()) {
604 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
605 return false;
606 }
607
608 auto& pb = it->second;
609
610 return pb.completed_count_ == pb.batch_size_;
611 }
612
613 10 void BRPCParameterClient::WaitForPrefetch(uint64_t prefetch_id) {
614
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 auto it = prefetch_batches_.find(prefetch_id);
615
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
10 if (it == prefetch_batches_.end()) {
616 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
617 return;
618 }
619 10 auto& pb = it->second;
620
2/2
✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
20 for (int i = 0; i < pb.batch_size_; ++i) {
621
1/2
✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
10 if (pb.controllers_[i]) {
622
2/4
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 10 times.
✗ Branch 7 not taken.
10 brpc::Join(pb.controllers_[i]->call_id());
623 }
624 }
625 10 pb.completed_count_ = pb.batch_size_;
626 }
627
628 10 bool BRPCParameterClient::GetPrefetchResult(
629 uint64_t prefetch_id, std::vector<std::vector<float>>* values) {
630
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 auto it = prefetch_batches_.find(prefetch_id);
631
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
10 if (it == prefetch_batches_.end()) {
632 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
633 return false;
634 }
635
636 10 auto& pb = it->second;
637 10 int request_num = pb.batch_size_;
638
639 10 values->clear();
640 10 int keys_size = 0;
641
2/2
✓ Branch 4 taken 10 times.
✓ Branch 5 taken 10 times.
20 for (const auto& size : pb.key_sizes_) {
642 10 keys_size += size;
643 }
644
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 values->reserve(keys_size);
645
646
2/2
✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
20 for (int i = 0; i < request_num; ++i) {
647
2/4
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
10 if (pb.controllers_[i]->Failed()) {
648 LOG(ERROR) << "Prefetch request failed: "
649 << pb.controllers_[i]->ErrorText();
650 return false;
651 }
652
653 10 auto& response = pb.responses_[i];
654 10 int key_size = pb.key_sizes_[i];
655 10 std::string payload_storage;
656 10 int payload_size = 0;
657
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 auto parameters = ExtractGetResponseReader(
658 10 *pb.controllers_[i], response, &payload_storage, &payload_size);
659
660
4/8
✓ Branch 0 taken 10 times.
✗ Branch 1 not taken.
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 10 times.
10 if (parameters == nullptr || !parameters->Valid(payload_size)) {
661 LOG(ERROR) << "Prefetch invalid payload: " << payload_size;
662 return false;
663 }
664
665
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
10 if (unlikely(parameters->size != key_size)) {
666 LOG(ERROR) << "GetParameter error: " << parameters->size << " vs "
667 << key_size;
668 return false;
669 }
670
671
2/2
✓ Branch 1 taken 106 times.
✓ Branch 2 taken 10 times.
116 for (int index = 0; index < parameters->item_size(); ++index) {
672
1/2
✓ Branch 1 taken 106 times.
✗ Branch 2 not taken.
106 auto item = parameters->item(index);
673
1/2
✓ Branch 0 taken 106 times.
✗ Branch 1 not taken.
106 if (item->dim != 0) {
674
1/2
✓ Branch 1 taken 106 times.
✗ Branch 2 not taken.
106 values->emplace_back(
675
1/2
✓ Branch 2 taken 106 times.
✗ Branch 3 not taken.
212 std::vector<float>(item->embedding, item->embedding + item->dim));
676 } else {
677 values->emplace_back(std::vector<float>(0));
678 }
679 }
680
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 }
681
682 // Remove completed batch
683
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 prefetch_batches_.erase(it);
684
685 10 return true;
686 }
687
688 bool BRPCParameterClient::GetPrefetchResultFlat(
689 uint64_t prefetch_id,
690 std::vector<float>* values,
691 int64_t* num_rows,
692 int64_t embedding_dim) {
693 auto it = prefetch_batches_.find(prefetch_id);
694 if (it == prefetch_batches_.end()) {
695 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
696 return false;
697 }
698 if (values == nullptr || num_rows == nullptr) {
699 LOG(ERROR) << "GetPrefetchResultFlat output pointer is null";
700 return false;
701 }
702
703 auto& pb = it->second;
704 int request_num = pb.batch_size_;
705 int total_keys = 0;
706 for (const auto& size : pb.key_sizes_) {
707 total_keys += size;
708 }
709
710 *num_rows = static_cast<int64_t>(total_keys);
711 values->assign(
712 static_cast<size_t>(*num_rows) * static_cast<size_t>(embedding_dim),
713 0.0f);
714
715 size_t row_offset = 0;
716 for (int i = 0; i < request_num; ++i) {
717 if (pb.controllers_[i]->Failed()) {
718 LOG(ERROR) << "Prefetch request failed: "
719 << pb.controllers_[i]->ErrorText();
720 return false;
721 }
722
723 auto& response = pb.responses_[i];
724 int key_size = pb.key_sizes_[i];
725 std::string payload_storage;
726 int payload_size = 0;
727 auto parameters = ExtractGetResponseReader(
728 *pb.controllers_[i], response, &payload_storage, &payload_size);
729
730 if (parameters == nullptr || !parameters->Valid(payload_size)) {
731 LOG(ERROR) << "Prefetch invalid payload: " << payload_size;
732 return false;
733 }
734
735 if (unlikely(parameters->size != key_size)) {
736 LOG(ERROR) << "GetParameter error: " << parameters->size << " vs "
737 << key_size;
738 return false;
739 }
740
741 for (int index = 0; index < parameters->item_size();
742 ++index, ++row_offset) {
743 auto item = parameters->item(index);
744 if (item->dim != 0) {
745 const int64_t copy_d =
746 std::min<int64_t>(embedding_dim, static_cast<int64_t>(item->dim));
747 std::memcpy(values->data() + row_offset * embedding_dim,
748 item->embedding,
749 static_cast<size_t>(copy_d) * sizeof(float));
750 }
751 }
752 }
753
754 prefetch_batches_.erase(it);
755 return true;
756 }
757
758 34 bool BRPCParameterClient::ClearPS() {
759
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 CommandRequest request;
760
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 CommandResponse response;
761
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 request.set_command(PSCommand::CLEAR_PS);
762
763
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 brpc::Controller cntl;
764
1/2
✓ Branch 2 taken 34 times.
✗ Branch 3 not taken.
34 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
765
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 stub.Command(&cntl, &request, &response, nullptr);
766
767
2/4
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 34 times.
34 if (cntl.Failed()) {
768 LOG(ERROR) << "bRPC Command failed: " << cntl.ErrorText();
769 return false;
770 }
771 34 return true;
772 34 }
773
774 6 bool BRPCParameterClient::LoadFakeData(int64_t data) {
775
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 CommandRequest request;
776
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 CommandResponse response;
777
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 request.set_command(PSCommand::LOAD_FAKE_DATA);
778
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 request.add_arg1(&data, sizeof(int64_t));
779
780
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 brpc::Controller cntl;
781
1/2
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
782
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 stub.Command(&cntl, &request, &response, nullptr);
783
784
2/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 if (cntl.Failed()) {
785 LOG(ERROR) << "bRPC LoadFakeData failed: " << cntl.ErrorText();
786 return false;
787 }
788
2/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
6 if (response.reply().size() != static_cast<size_t>(data)) {
789 LOG(ERROR) << "bRPC LoadFakeData reply size mismatch: expected " << data
790 << ", got " << response.reply().size();
791 return false;
792 }
793 6 return true;
794 6 }
795
796 6 bool BRPCParameterClient::DumpFakeData(int64_t n) {
797
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 CommandRequest request;
798
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 CommandResponse response;
799
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 request.set_command(PSCommand::DUMP_FAKE_DATA);
800
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 request.add_arg1(&n, sizeof(int64_t));
801
802
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 brpc::Controller cntl;
803
1/2
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
804
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 stub.Command(&cntl, &request, &response, nullptr);
805
806
2/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 if (cntl.Failed()) {
807 LOG(ERROR) << "bRPC DumpFakeData failed: " << cntl.ErrorText();
808 return false;
809 }
810
2/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
6 if (response.reply() != "ok") {
811 LOG(ERROR) << "bRPC DumpFakeData unexpected reply: " << response.reply();
812 return false;
813 }
814 6 return true;
815 6 }
816
817 bool BRPCParameterClient::LoadCkpt(
818 const std::vector<std::string>& model_config_path,
819 const std::vector<std::string>& emb_file_path) {
820 CommandRequest request;
821 CommandResponse response;
822 request.set_command(PSCommand::RELOAD_PS);
823
824 for (auto& each : model_config_path) {
825 request.add_arg1(each);
826 }
827 for (auto& each : emb_file_path) {
828 request.add_arg2(each);
829 }
830
831 brpc::Controller cntl;
832 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
833 stub.Command(&cntl, &request, &response, nullptr);
834
835 if (cntl.Failed()) {
836 LOG(ERROR) << "bRPC LoadCkpt failed: " << cntl.ErrorText();
837 return false;
838 }
839 return true;
840 }
841
842 20 bool BRPCParameterClient::PutParameter(
843 const std::vector<uint64_t>& keys,
844 const std::vector<std::vector<float>>& values) {
845 #ifdef ENABLE_PERF_REPORT
846 auto start_time = std::chrono::high_resolution_clock::now();
847 #endif
848
849
1/2
✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
20 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
850
851
2/2
✓ Branch 1 taken 20 times.
✓ Branch 2 taken 20 times.
40 for (int start = 0, index = 0; start < keys.size();
852 20 start += MAX_PARAMETER_BATCH_BRPC, ++index) {
853 int key_size =
854 20 std::min((int)(keys.size() - start), MAX_PARAMETER_BATCH_BRPC);
855
856
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 PutParameterRequest request;
857
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 PutParameterResponse response;
858
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 ParameterCompressor compressor;
859
860
2/2
✓ Branch 0 taken 234 times.
✓ Branch 1 taken 20 times.
254 for (int i = start; i < start + key_size; i++) {
861 234 auto each_key = keys[i];
862 234 auto& embedding = values[i];
863 234 ParameterPack parameter_pack;
864 234 parameter_pack.key = each_key;
865 234 parameter_pack.dim = embedding.size();
866 234 parameter_pack.emb_data = embedding.data();
867
1/2
✓ Branch 1 taken 234 times.
✗ Branch 2 not taken.
234 compressor.AddItem(parameter_pack, nullptr);
868 }
869
870
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 brpc::Controller cntl;
871
1/2
✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
20 compressor.AppendToIOBuf(&cntl.request_attachment());
872
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 stub.PutParameter(&cntl, &request, &response, nullptr);
873
874
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
20 if (cntl.Failed()) {
875 LOG(ERROR) << "bRPC PutParameter failed: " << cntl.ErrorText();
876 return false;
877 }
878
4/8
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
20 }
879
880 #ifdef ENABLE_PERF_REPORT
881 auto end_time = std::chrono::high_resolution_clock::now();
882 auto duration =
883 std::chrono::duration_cast<std::chrono::microseconds>(
884 end_time - start_time)
885 .count();
886 report("ps_client_latency",
887 "PutParameter",
888 "latency_us",
889 static_cast<double>(duration));
890 #endif
891
892 20 return true;
893 20 }
894
895 int BRPCParameterClient::AsyncGetParameter(
896 const base::ConstArray<uint64_t>& keys, float* values) {
897 return GetParameter(keys, values);
898 }
899
900 14 int BRPCParameterClient::PutParameter(
901 const base::ConstArray<uint64_t>& keys,
902 const std::vector<std::vector<float>>& values) {
903
1/2
✓ Branch 5 taken 14 times.
✗ Branch 6 not taken.
14 std::vector<uint64_t> key_vec(keys.Data(), keys.Data() + keys.Size());
904
1/2
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
14 bool success = PutParameter(key_vec, values);
905
1/2
✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
28 return success ? 1 : 0;
906 14 }
907
908 void BRPCParameterClient::Command(recstore::PSCommand command) {
909 switch (command) {
910 case recstore::PSCommand::CLEAR_PS:
911 ClearPS();
912 break;
913 case recstore::PSCommand::RELOAD_PS:
914 LOG(WARNING) << "RELOAD_PS command requires additional parameters";
915 break;
916 case recstore::PSCommand::LOAD_FAKE_DATA: {
917 int64_t fake_data = 1000;
918 LoadFakeData(fake_data);
919 } break;
920 case recstore::PSCommand::DUMP_FAKE_DATA: {
921 DumpFakeData(4096);
922 } break;
923 default:
924 LOG(ERROR) << "Unknown PS command: " << static_cast<int>(command);
925 break;
926 }
927 }
928
929 int BRPCParameterClient::UpdateParameter(
930 const std::string& table_name,
931 const base::ConstArray<uint64_t>& keys,
932 const std::vector<std::vector<float>>* grads) {
933 #ifdef ENABLE_PERF_REPORT
934 auto start_time = std::chrono::high_resolution_clock::now();
935 const uint64_t trace_id = recstore::g_trace_id;
936 #endif
937 if (grads == nullptr) {
938 LOG(ERROR) << "UpdateParameter grads pointer is null";
939 return -1;
940 }
941 if (keys.Size() != grads->size()) {
942 LOG(ERROR) << "UpdateParameter keys/grads size mismatch: " << keys.Size()
943 << " vs " << grads->size();
944 return -1;
945 }
946
947 ParameterCompressor compressor;
948 for (size_t i = 0; i < keys.Size(); ++i) {
949 ParameterPack pack;
950 pack.key = keys[i];
951 pack.dim = grads->at(i).size();
952 pack.emb_data = grads->at(i).data();
953 compressor.AddItem(pack, nullptr);
954 }
955 #ifdef ENABLE_PERF_REPORT
956 auto serialize_done_time = std::chrono::high_resolution_clock::now();
957 #endif
958 if (keys.Size() == 0) {
959 LOG(WARNING) << "UpdateParameter no gradients to send";
960 return 0;
961 }
962
963 UpdateParameterRequest request;
964 UpdateParameterResponse response;
965 request.set_table_name(table_name);
966
967 brpc::Controller cntl;
968 #ifdef ENABLE_PERF_REPORT
969 if (trace_id != 0) {
970 cntl.http_request().SetHeader(
971 "x-recstore-trace-id", std::to_string(trace_id));
972 }
973 auto rpc_start_time = std::chrono::high_resolution_clock::now();
974 #endif
975 compressor.AppendToIOBuf(&cntl.request_attachment());
976 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
977 stub.UpdateParameter(&cntl, &request, &response, nullptr);
978 if (cntl.Failed()) {
979 LOG(ERROR) << "UpdateParameter RPC failed: " << cntl.ErrorText();
980 return -1;
981 }
982
983 #ifdef ENABLE_PERF_REPORT
984 auto end_time = std::chrono::high_resolution_clock::now();
985 auto duration =
986 std::chrono::duration_cast<std::chrono::microseconds>(
987 end_time - start_time)
988 .count();
989 auto serialize_duration =
990 std::chrono::duration_cast<std::chrono::microseconds>(
991 serialize_done_time - start_time)
992 .count();
993 auto rpc_duration =
994 std::chrono::duration_cast<std::chrono::microseconds>(
995 end_time - rpc_start_time)
996 .count();
997 report("ps_client_latency",
998 "UpdateParameter",
999 "latency_us",
1000 static_cast<double>(duration));
1001
1002 const uint64_t effective_trace_id =
1003 trace_id == 0
1004 ? static_cast<uint64_t>(
1005 std::chrono::duration_cast<std::chrono::microseconds>(
1006 start_time.time_since_epoch())
1007 .count())
1008 : trace_id;
1009 std::string stage_id =
1010 "brpc_client::EmbUpdate|" + std::to_string(effective_trace_id);
1011 report("embupdate_stages",
1012 stage_id.c_str(),
1013 "client_serialize_us",
1014 static_cast<double>(serialize_duration));
1015 report("embupdate_stages",
1016 stage_id.c_str(),
1017 "client_rpc_us",
1018 static_cast<double>(rpc_duration));
1019 report("embupdate_stages",
1020 stage_id.c_str(),
1021 "client_total_us",
1022 static_cast<double>(duration));
1023 report("embupdate_stages",
1024 stage_id.c_str(),
1025 "client_request_size",
1026 static_cast<double>(keys.Size()));
1027 #endif
1028
1029 return response.success() ? 0 : -1;
1030 }
1031
1032 int BRPCParameterClient::UpdateParameterFlat(
1033 const std::string& table_name,
1034 const base::ConstArray<uint64_t>& keys,
1035 const float* grads,
1036 int64_t num_rows,
1037 int64_t embedding_dim) {
1038 #ifdef ENABLE_PERF_REPORT
1039 auto start_time = std::chrono::high_resolution_clock::now();
1040 const uint64_t trace_id = recstore::g_trace_id;
1041 #endif
1042 if (keys.Size() == 0) {
1043 return 0;
1044 }
1045
1046 ParameterCompressor compressor;
1047 if (BuildUpdateBlocksFromFlat(
1048 keys, grads, num_rows, embedding_dim, &compressor) != 0) {
1049 return -1;
1050 }
1051 #ifdef ENABLE_PERF_REPORT
1052 auto serialize_done_time = std::chrono::high_resolution_clock::now();
1053 #endif
1054
1055 UpdateParameterRequest request;
1056 UpdateParameterResponse response;
1057 request.set_table_name(table_name);
1058
1059 brpc::Controller cntl;
1060 #ifdef ENABLE_PERF_REPORT
1061 if (trace_id != 0) {
1062 cntl.http_request().SetHeader(
1063 "x-recstore-trace-id", std::to_string(trace_id));
1064 }
1065 auto rpc_start_time = std::chrono::high_resolution_clock::now();
1066 #endif
1067 compressor.AppendToIOBuf(&cntl.request_attachment());
1068 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
1069 stub.UpdateParameter(&cntl, &request, &response, nullptr);
1070 if (cntl.Failed()) {
1071 LOG(ERROR) << "UpdateParameterFlat RPC failed: " << cntl.ErrorText();
1072 return -1;
1073 }
1074
1075 #ifdef ENABLE_PERF_REPORT
1076 auto end_time = std::chrono::high_resolution_clock::now();
1077 auto duration =
1078 std::chrono::duration_cast<std::chrono::microseconds>(
1079 end_time - start_time)
1080 .count();
1081 auto serialize_duration =
1082 std::chrono::duration_cast<std::chrono::microseconds>(
1083 serialize_done_time - start_time)
1084 .count();
1085 auto rpc_duration =
1086 std::chrono::duration_cast<std::chrono::microseconds>(
1087 end_time - rpc_start_time)
1088 .count();
1089 report("ps_client_latency",
1090 "UpdateParameterFlat",
1091 "latency_us",
1092 static_cast<double>(duration));
1093
1094 const uint64_t effective_trace_id =
1095 trace_id == 0
1096 ? static_cast<uint64_t>(
1097 std::chrono::duration_cast<std::chrono::microseconds>(
1098 start_time.time_since_epoch())
1099 .count())
1100 : trace_id;
1101 std::string stage_id =
1102 "brpc_client::EmbUpdate|" + std::to_string(effective_trace_id);
1103 report("embupdate_stages",
1104 stage_id.c_str(),
1105 "client_serialize_us",
1106 static_cast<double>(serialize_duration));
1107 report("embupdate_stages",
1108 stage_id.c_str(),
1109 "client_rpc_us",
1110 static_cast<double>(rpc_duration));
1111 report("embupdate_stages",
1112 stage_id.c_str(),
1113 "client_total_us",
1114 static_cast<double>(duration));
1115 report("embupdate_stages",
1116 stage_id.c_str(),
1117 "client_request_size",
1118 static_cast<double>(num_rows));
1119 report("embupdate_stages",
1120 stage_id.c_str(),
1121 "client_embedding_dim",
1122 static_cast<double>(embedding_dim));
1123 #endif
1124
1125 return response.success() ? 0 : -1;
1126 }
1127
1128 int BRPCParameterClient::InitEmbeddingTable(
1129 const std::string& table_name,
1130 const recstore::EmbeddingTableConfig& config) {
1131 #ifdef ENABLE_PERF_REPORT
1132 auto start_time = std::chrono::high_resolution_clock::now();
1133 #endif
1134
1135 InitEmbeddingTableRequest request;
1136 InitEmbeddingTableResponse response;
1137 request.set_table_name(table_name);
1138 request.set_config_payload(config.Serialize());
1139
1140 brpc::Controller cntl;
1141 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
1142 stub.InitEmbeddingTable(&cntl, &request, &response, nullptr);
1143 if (cntl.Failed()) {
1144 LOG(ERROR) << "InitEmbeddingTable RPC failed: " << cntl.ErrorText();
1145 return -1;
1146 }
1147
1148 #ifdef ENABLE_PERF_REPORT
1149 auto end_time = std::chrono::high_resolution_clock::now();
1150 auto duration =
1151 std::chrono::duration_cast<std::chrono::microseconds>(
1152 end_time - start_time)
1153 .count();
1154 report("ps_client_latency",
1155 "InitEmbeddingTable",
1156 "latency_us",
1157 static_cast<double>(duration));
1158 #endif
1159
1160 return response.success() ? 0 : -1;
1161 }
1162
1163 8 uint64_t BRPCParameterClient::EmbWriteAsync(const base::RecTensor& keys,
1164 const base::RecTensor& values) {
1165
3/6
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
8 if (keys.dtype() != base::DataType::UINT64 || keys.dim() != 1) {
1166 LOG(ERROR) << "EmbWriteAsync expects keys as 1D UINT64 tensor, got dtype="
1167 << base::DataTypeToString(keys.dtype())
1168 << ", dim=" << keys.dim();
1169 return 0;
1170 }
1171
3/6
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
8 if (values.dtype() != base::DataType::FLOAT32 || values.dim() != 2) {
1172 LOG(ERROR)
1173 << "EmbWriteAsync expects values as 2D FLOAT32 tensor, got dtype="
1174 << base::DataTypeToString(values.dtype()) << ", dim=" << values.dim();
1175 return 0;
1176 }
1177
3/6
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
8 if (values.shape(0) != keys.shape(0)) {
1178 LOG(ERROR) << "EmbWriteAsync row mismatch: keys=" << keys.shape(0)
1179 << ", values=" << values.shape(0);
1180 return 0;
1181 }
1182
2/4
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
8 if (values.shape(1) <= 0) {
1183 LOG(ERROR) << "EmbWriteAsync invalid embedding dim: " << values.shape(1);
1184 return 0;
1185 }
1186
1187
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 const uint64_t* key_data = keys.data_as<uint64_t>();
1188
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 const float* value_data = values.data_as<float>();
1189
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 int64_t key_count = keys.shape(0);
1190
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 int64_t emb_dim = values.shape(1);
1191
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
8 if (key_count == 0) {
1192 return 0;
1193 }
1194
1195 8 uint64_t prewrite_id = next_prewrite_id_++;
1196 8 int request_num =
1197 8 (static_cast<int>(key_count) + MAX_PARAMETER_BATCH_BRPC - 1) /
1198 MAX_PARAMETER_BATCH_BRPC;
1199
1200
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 auto it = prewrite_batches_.emplace(prewrite_id, request_num).first;
1201 8 struct BrpcPrewriteBatch* pb = &it->second;
1202
1203
1/2
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
8 recstoreps_brpc::ParameterService_Stub stub(channel_.get());
1204
2/2
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
16 for (int start = 0, index = 0; start < key_count;
1205 8 start += MAX_PARAMETER_BATCH_BRPC, ++index) {
1206 int key_size =
1207 8 std::min(static_cast<int>(key_count - start), MAX_PARAMETER_BATCH_BRPC);
1208 8 pb->key_sizes_[index] = key_size;
1209
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 pb->controllers_[index] = std::make_unique<brpc::Controller>();
1210
1211
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 ParameterCompressor compressor;
1212
2/2
✓ Branch 0 taken 96 times.
✓ Branch 1 taken 8 times.
104 for (int i = 0; i < key_size; ++i) {
1213 96 int64_t row = start + i;
1214 96 ParameterPack parameter_pack;
1215 96 parameter_pack.key = key_data[row];
1216 96 parameter_pack.dim = emb_dim;
1217 96 parameter_pack.emb_data = value_data + row * emb_dim;
1218
1/2
✓ Branch 1 taken 96 times.
✗ Branch 2 not taken.
96 compressor.AddItem(parameter_pack, nullptr);
1219 }
1220
1221
1/2
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
8 compressor.AppendToIOBuf(&pb->controllers_[index]->request_attachment());
1222
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 google::protobuf::Closure* done = brpc::NewCallback(OnPrewriteDone, pb);
1223
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 stub.PutParameter(
1224 8 pb->controllers_[index].get(),
1225 8 &pb->requests_[index],
1226 8 &pb->responses_[index],
1227 done);
1228 8 }
1229
1230 8 return prewrite_id;
1231 8 }
1232
1233 bool BRPCParameterClient::IsWriteDone(uint64_t write_id) {
1234 auto it = prewrite_batches_.find(write_id);
1235 if (it == prewrite_batches_.end()) {
1236 LOG(ERROR) << "Invalid prewrite_id: " << write_id;
1237 return false;
1238 }
1239 auto& pb = it->second;
1240 return pb.completed_count_ == pb.batch_size_;
1241 }
1242
1243 8 void BRPCParameterClient::WaitForWrite(uint64_t write_id) {
1244
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 auto it = prewrite_batches_.find(write_id);
1245
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
8 if (it == prewrite_batches_.end()) {
1246 LOG(ERROR) << "Invalid prewrite_id: " << write_id;
1247 return;
1248 }
1249 8 auto& pb = it->second;
1250
2/2
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
16 for (int i = 0; i < pb.batch_size_; ++i) {
1251
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
8 if (!pb.controllers_[i]) {
1252 continue;
1253 }
1254
2/4
✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
8 brpc::Join(pb.controllers_[i]->call_id());
1255
2/4
✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
8 if (pb.controllers_[i]->Failed()) {
1256 LOG(ERROR) << "Async PutParameter failed: "
1257 << pb.controllers_[i]->ErrorText();
1258 }
1259 }
1260 8 pb.completed_count_ = pb.batch_size_;
1261
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 prewrite_batches_.erase(it);
1262 }
1263
1264 // Register BRPCParameterClient with the factory
1265 using BasePSClient = recstore::BasePSClient;
1266 FACTORY_REGISTER(BasePSClient, brpc, BRPCParameterClient, json);
1267