GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 54.3% 313 / 0 / 576
Functions: 68.0% 17 / 0 / 25
Branches: 23.6% 225 / 0 / 955

ps/grpc/grpc_ps_client.cpp
Line Branch Exec Source
1 #include "grpc_ps_client.h"
2
3 #include <fmt/core.h>
4 #include <grpcpp/grpcpp.h>
5
6 #include <cstdint>
7 #include <cstring>
8 #include <future>
9 #include <string>
10 #include <vector>
11
12 #include "base/array.h"
13 #include "base/factory.h"
14 #include "base/flatc.h"
15 #include "base/log.h"
16 #include "base/timer.h"
17 #include "ps/base/parameters.h"
18 #include "ps.grpc.pb.h"
19 #include "ps.pb.h"
20
21 #ifdef ENABLE_PERF_REPORT
22 # include "base/report/report_client.h"
23 # include <chrono>
24 #endif
25
26 using grpc::Channel;
27 using grpc::ClientAsyncResponseReader;
28 using grpc::ClientContext;
29 using grpc::Status;
30 using recstoreps::CommandRequest;
31 using recstoreps::CommandResponse;
32 using recstoreps::GetParameterRequest;
33 using recstoreps::GetParameterResponse;
34 using recstoreps::InitEmbeddingTableRequest;
35 using recstoreps::InitEmbeddingTableResponse;
36 using recstoreps::PSCommand;
37 using recstoreps::PutParameterRequest;
38 using recstoreps::PutParameterResponse;
39 using recstoreps::UpdateParameterRequest;
40 using recstoreps::UpdateParameterResponse;
41
42 namespace {
43
44 198 void SetRpcDeadline(grpc::ClientContext* context, int timeout_ms = 15000) {
45
1/2
✓ Branch 1 taken 198 times.
✗ Branch 2 not taken.
198 context->set_deadline(
46
1/2
✓ Branch 3 taken 198 times.
✗ Branch 4 not taken.
198 std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms));
47 198 }
48
49 int BuildUpdateBlocksFromFlat(
50 const base::ConstArray<uint64_t>& keys,
51 const float* grads,
52 int64_t num_rows,
53 int64_t embedding_dim,
54 ParameterCompressor* compressor) {
55 if (grads == nullptr) {
56 LOG(ERROR) << "UpdateParameterFlat grads pointer is null";
57 return -1;
58 }
59 if (num_rows < 0 || embedding_dim <= 0) {
60 LOG(ERROR) << "UpdateParameterFlat invalid shape: rows=" << num_rows
61 << " dim=" << embedding_dim;
62 return -1;
63 }
64 if (keys.Size() != static_cast<size_t>(num_rows)) {
65 LOG(ERROR) << "UpdateParameterFlat keys/grads size mismatch: "
66 << keys.Size() << " vs " << num_rows;
67 return -1;
68 }
69
70 for (int64_t i = 0; i < num_rows; ++i) {
71 ParameterPack pack;
72 pack.key = keys[static_cast<size_t>(i)];
73 pack.dim = embedding_dim;
74 pack.emb_data = grads + i * embedding_dim;
75 compressor->AddItem(pack, nullptr);
76 }
77 return 0;
78 }
79
80 } // namespace
81
82 DEFINE_int32(get_parameter_threads, 4, "get clients per shard");
83 DEFINE_bool(parameter_client_random_init, false, "");
84
85 // New constructor that takes JSON config.
86 /*
87 Example: load config from file
88 std::ifstream config_file(FLAGS_config_path);
89 nlohmann::json ex;
90 config_file >> ex;
91 json client_config = ex["client"];
92
93 */
94 38 GRPCParameterClient::GRPCParameterClient(json config)
95
1/2
✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
38 : recstore::BasePSClient(config) {
96 // Extract fields from JSON config
97
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 host_ = config.value("host", "localhost");
98
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 port_ = config.value("port", 15000);
99
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 shard_ = config.value("shard", 0);
100 38 nr_clients_ = FLAGS_get_parameter_threads;
101 38 Initialize();
102
103
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 grpc::ChannelArguments args;
104
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 args.SetMaxReceiveMessageSize(-1);
105
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 args.SetMaxSendMessageSize(-1);
106
107 38 channel_ = grpc::CreateCustomChannel(
108
2/4
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
114 fmt::format("{}:{}", host_, port_),
109
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
76 grpc::InsecureChannelCredentials(),
110 38 args);
111
2/4
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
38 auto* raw_cq = new grpc::CompletionQueue();
112 38 cq_.reset(raw_cq);
113
114
2/2
✓ Branch 0 taken 152 times.
✓ Branch 1 taken 38 times.
190 for (int i = 0; i < nr_clients_; i++) {
115
1/2
✓ Branch 2 taken 152 times.
✗ Branch 3 not taken.
152 stubs_.push_back(nullptr);
116
1/2
✓ Branch 3 taken 152 times.
✗ Branch 4 not taken.
152 stubs_[i] = recstoreps::ParameterService::NewStub(channel_);
117
4/8
✓ Branch 1 taken 152 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 152 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 152 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 152 times.
✗ Branch 11 not taken.
152 LOG(INFO) << "Init PS Client Shard " << i;
118 }
119 38 }
120
121 // Legacy constructor for backward compatibility
122 2 GRPCParameterClient::GRPCParameterClient(
123 2 const std::string& host, int port, int shard)
124 : recstore::BasePSClient(
125
14/28
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 22 taken 6 times.
✓ Branch 23 taken 2 times.
✓ Branch 25 taken 4 times.
✓ Branch 26 taken 2 times.
✓ Branch 28 taken 4 times.
✓ Branch 29 taken 2 times.
✓ Branch 31 taken 4 times.
✓ Branch 32 taken 2 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
22 json{{"host", host}, {"port", port}, {"shard", shard}}),
126 2 host_(host),
127 2 port_(port),
128 2 shard_(shard),
129
2/4
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
6 nr_clients_(FLAGS_get_parameter_threads) {
130 2 Initialize();
131
132
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 grpc::ChannelArguments args;
133
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 args.SetMaxReceiveMessageSize(-1);
134
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 args.SetMaxSendMessageSize(-1);
135
136 2 channel_ = grpc::CreateCustomChannel(
137
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
4 fmt::format("{}:{}", host, port),
138
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
4 grpc::InsecureChannelCredentials(),
139 2 args);
140
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
2 auto* raw_cq = new grpc::CompletionQueue();
141 2 cq_.reset(raw_cq);
142
143
2/2
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 2 times.
10 for (int i = 0; i < nr_clients_; i++) {
144
1/2
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
8 stubs_.push_back(nullptr);
145
1/2
✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
8 stubs_[i] = recstoreps::ParameterService::NewStub(channel_);
146
4/8
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
8 LOG(INFO) << "Init PS Client Shard " << i;
147 }
148 2 }
149
150 int GRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys,
151 float* values) {
152 #ifdef ENABLE_PERF_REPORT
153 auto start_time = std::chrono::high_resolution_clock::now();
154 #endif
155
156 if (FLAGS_parameter_client_random_init) {
157 CHECK(0) << "todo implement";
158 return true;
159 }
160
161 get_param_key_sizes_.clear();
162 get_param_status_.clear();
163 get_param_requests_.clear();
164 get_param_responses_.clear();
165 get_param_resonse_readers_.clear();
166 get_param_contexts_.clear();
167
168 int request_num =
169 (keys.Size() + MAX_PARAMETER_BATCH - 1) / MAX_PARAMETER_BATCH;
170 get_param_status_.resize(request_num);
171 get_param_requests_.resize(request_num);
172 get_param_responses_.resize(request_num);
173 get_param_contexts_.resize(request_num);
174
175 for (int start = 0, index = 0; start < keys.Size();
176 start += MAX_PARAMETER_BATCH, ++index) {
177 int key_size = std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH);
178 get_param_key_sizes_.emplace_back(key_size);
179 auto& status = get_param_status_[index];
180 auto& request = get_param_requests_[index];
181 auto& response = get_param_responses_[index];
182 request.set_keys(reinterpret_cast<const char*>(&keys[start]),
183 sizeof(uint64_t) * key_size);
184 // rpc
185 // grpc::ClientContext context;
186 if (!get_param_contexts_[index]) {
187 get_param_contexts_[index] = std::make_unique<grpc::ClientContext>();
188 }
189 std::unique_ptr<ClientAsyncResponseReader<GetParameterResponse>> rpc =
190 stubs_[0]->AsyncGetParameter(
191 get_param_contexts_[index].get(), request, cq_.get());
192 rpc->Finish(&response, &status, reinterpret_cast<void*>(index));
193 }
194 int get = 0;
195 while (get != request_num) {
196 void* got_tag;
197 bool ok = false;
198 cq_->Next(&got_tag, &ok);
199 if (!ok) {
200 LOG(ERROR) << "error";
201 }
202 get++;
203 }
204 #ifdef ENABLE_PERF_REPORT
205 auto after_rpc_time = std::chrono::high_resolution_clock::now();
206 auto rpc_duration =
207 std::chrono::duration_cast<std::chrono::microseconds>(
208 after_rpc_time - start_time)
209 .count();
210 double start_us_for_rpc =
211 std::chrono::duration_cast<std::chrono::microseconds>(
212 start_time.time_since_epoch())
213 .count();
214 std::string report_id_for_rpc =
215 "grpc_client::GetParameter|" +
216 std::to_string(static_cast<uint64_t>(start_us_for_rpc));
217 report("embread_stages",
218 report_id_for_rpc.c_str(),
219 "rpc_duration_us",
220 static_cast<double>(rpc_duration));
221 #endif
222 size_t get_embedding_acc = 0;
223 int old_dimension = -1;
224
225 for (int i = 0; i < get_param_responses_.size(); ++i) {
226 auto& response = get_param_responses_[i];
227 int key_size = get_param_key_sizes_[i];
228 auto parameters = reinterpret_cast<const ParameterCompressReader*>(
229 response.parameter_value().data());
230
231 if (parameters->size != key_size) {
232 LOG(ERROR) << "GetParameter error: " << parameters->size << " vs "
233 << key_size;
234 return false;
235 }
236
237 for (int index = 0; index < parameters->item_size(); ++index) {
238 auto item = parameters->item(index);
239 if (item->dim != 0) {
240 if (old_dimension == -1)
241 old_dimension = item->dim;
242 CHECK_EQ(item->dim, old_dimension);
243 std::copy_n(
244 item->embedding, item->dim, values + item->dim * get_embedding_acc);
245 } else {
246 RECSTORE_LOG_EVERY_MS(ERROR, 2000)
247 << "error; not find key " << keys[get_embedding_acc] << " in ps";
248 }
249 get_embedding_acc++;
250 }
251 }
252
253 #ifdef ENABLE_PERF_REPORT
254 auto end_time = std::chrono::high_resolution_clock::now();
255 auto duration =
256 std::chrono::duration_cast<std::chrono::microseconds>(
257 end_time - start_time)
258 .count();
259 double start_us =
260 std::chrono::duration_cast<std::chrono::microseconds>(
261 start_time.time_since_epoch())
262 .count();
263
264 auto deserialize_duration =
265 std::chrono::duration_cast<std::chrono::microseconds>(
266 end_time - after_rpc_time)
267 .count();
268
269 report("embread_stages",
270 "grpc_client::GetParameter",
271 "deserialize_duration_us",
272 static_cast<double>(deserialize_duration));
273
274 report("embread_stages",
275 "grpc_client::GetParameter",
276 "duration_us",
277 static_cast<double>(duration));
278
279 report("embread_stages",
280 "grpc_client::GetParameter",
281 "request_size",
282 static_cast<double>(keys.Size()));
283
284 FlameGraphData grpc_client_data = {
285 "grpc_ps_client::GetParameter",
286 start_us,
287 1, // level
288 static_cast<double>(duration),
289 static_cast<double>(duration)};
290
291 std::string unique_id = "embread_debug";
292 report_flame_graph("emb_read_flame_map", unique_id.c_str(), grpc_client_data);
293 #endif
294
295 return true;
296 }
297
298 92 int GRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys,
299 std::vector<std::vector<float>>* values) {
300 #ifdef ENABLE_PERF_REPORT
301 auto start_time = std::chrono::high_resolution_clock::now();
302 #endif
303
304
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 92 times.
92 if (FLAGS_parameter_client_random_init) {
305 values->clear();
306 values->reserve(keys.Size());
307 for (size_t i = 0; i < keys.Size(); i++)
308 values->emplace_back(std::vector<float>(128, 0.1));
309
310 return true;
311 }
312
313 92 values->clear();
314 92 get_param_key_sizes_.clear();
315 92 get_param_status_.clear();
316 92 get_param_requests_.clear();
317 92 get_param_responses_.clear();
318 92 get_param_resonse_readers_.clear();
319 92 get_param_contexts_.clear();
320
321 92 values->reserve(keys.Size());
322
323 int request_num =
324 92 (keys.Size() + MAX_PARAMETER_BATCH - 1) / MAX_PARAMETER_BATCH;
325
326 92 get_param_status_.resize(request_num);
327 92 get_param_requests_.resize(request_num);
328 92 get_param_responses_.resize(request_num);
329 92 get_param_contexts_.resize(request_num);
330
331
2/2
✓ Branch 1 taken 92 times.
✓ Branch 2 taken 92 times.
184 for (int start = 0, index = 0; start < keys.Size();
332 92 start += MAX_PARAMETER_BATCH, ++index) {
333 92 int key_size = std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH);
334
1/2
✓ Branch 1 taken 92 times.
✗ Branch 2 not taken.
92 get_param_key_sizes_.emplace_back(key_size);
335 92 auto& status = get_param_status_[index];
336 92 auto& request = get_param_requests_[index];
337 92 auto& response = get_param_responses_[index];
338 92 request.set_keys(reinterpret_cast<const char*>(&keys[start]),
339 92 sizeof(uint64_t) * key_size);
340 // rpc
341 // grpc::ClientContext context;
342
1/2
✓ Branch 2 taken 92 times.
✗ Branch 3 not taken.
92 if (!get_param_contexts_[index]) {
343
1/2
✓ Branch 1 taken 92 times.
✗ Branch 2 not taken.
92 get_param_contexts_[index] = std::make_unique<grpc::ClientContext>();
344 }
345
2/4
✓ Branch 5 taken 92 times.
✗ Branch 6 not taken.
✓ Branch 8 taken 92 times.
✗ Branch 9 not taken.
184 get_param_resonse_readers_.emplace_back(stubs_[0]->AsyncGetParameter(
346 92 get_param_contexts_[index].get(), request, cq_.get()));
347 92 auto& rpc = get_param_resonse_readers_.back();
348 // GetParameter(&context, request, &response);
349
1/2
✓ Branch 2 taken 92 times.
✗ Branch 3 not taken.
92 rpc->Finish(&response, &status, reinterpret_cast<void*>(index));
350 }
351
352 92 int get = 0;
353
2/2
✓ Branch 0 taken 92 times.
✓ Branch 1 taken 92 times.
184 while (get != request_num) {
354 void* got_tag;
355 92 bool ok = false;
356
1/2
✓ Branch 2 taken 92 times.
✗ Branch 3 not taken.
92 cq_->Next(&got_tag, &ok);
357
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 92 times.
92 if (unlikely(!ok)) {
358 LOG(ERROR) << "error";
359 }
360 92 get++;
361 }
362
363 #ifdef ENABLE_PERF_REPORT
364 auto after_rpc_time = std::chrono::high_resolution_clock::now();
365 auto rpc_duration =
366 std::chrono::duration_cast<std::chrono::microseconds>(
367 after_rpc_time - start_time)
368 .count();
369 double start_us_for_rpc =
370 std::chrono::duration_cast<std::chrono::microseconds>(
371 start_time.time_since_epoch())
372 .count();
373 std::string report_id_for_rpc =
374 "grpc_client::GetParameter|" +
375 std::to_string(static_cast<uint64_t>(start_us_for_rpc));
376 report("embread_stages",
377 report_id_for_rpc.c_str(),
378 "rpc_duration_us",
379 static_cast<double>(rpc_duration));
380 #endif
381
382
2/2
✓ Branch 1 taken 92 times.
✓ Branch 2 taken 92 times.
184 for (int i = 0; i < get_param_responses_.size(); ++i) {
383 92 auto& response = get_param_responses_[i];
384 92 int key_size = get_param_key_sizes_[i];
385 auto parameters = reinterpret_cast<const ParameterCompressReader*>(
386 92 response.parameter_value().data());
387
388
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 92 times.
92 if (unlikely(parameters->size != key_size)) {
389 LOG(ERROR) << "GetParameter error: " << parameters->size << " vs "
390 << key_size;
391 return false;
392 }
393
394
2/2
✓ Branch 1 taken 380 times.
✓ Branch 2 taken 92 times.
472 for (int index = 0; index < parameters->item_size(); ++index) {
395 380 auto item = parameters->item(index);
396
2/2
✓ Branch 0 taken 368 times.
✓ Branch 1 taken 12 times.
380 if (item->dim != 0) {
397
1/2
✓ Branch 1 taken 368 times.
✗ Branch 2 not taken.
368 values->emplace_back(
398
1/2
✓ Branch 2 taken 368 times.
✗ Branch 3 not taken.
736 std::vector<float>(item->embedding, item->embedding + item->dim));
399 } else {
400
2/4
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 12 times.
✗ Branch 6 not taken.
12 values->emplace_back(std::vector<float>(0));
401 }
402 }
403 }
404
405 #ifdef ENABLE_PERF_REPORT
406 auto end_time = std::chrono::high_resolution_clock::now();
407 auto duration =
408 std::chrono::duration_cast<std::chrono::microseconds>(
409 end_time - start_time)
410 .count();
411 double start_us =
412 std::chrono::duration_cast<std::chrono::microseconds>(
413 start_time.time_since_epoch())
414 .count();
415
416 auto deserialize_duration =
417 std::chrono::duration_cast<std::chrono::microseconds>(
418 end_time - after_rpc_time)
419 .count();
420
421 report("embread_stages",
422 "grpc_client::GetParameter",
423 "deserialize_duration_us",
424 static_cast<double>(deserialize_duration));
425
426 report("embread_stages",
427 "grpc_client::GetParameter",
428 "duration_us",
429 static_cast<double>(duration));
430
431 report("embread_stages",
432 "grpc_client::GetParameter",
433 "request_size",
434 static_cast<double>(keys.Size()));
435
436 FlameGraphData grpc_client_data = {
437 "grpc_ps_client::GetParameter",
438 start_us,
439 1, // level
440 static_cast<double>(duration),
441 static_cast<double>(duration)};
442
443 std::string unique_id = "embread_debug";
444 report_flame_graph("emb_read_flame_map", unique_id.c_str(), grpc_client_data);
445 #endif
446
447 92 return true;
448 }
449
450 // return prefetch id
451 uint64_t
452 112 GRPCParameterClient::PrefetchParameter(const base::ConstArray<uint64_t>& keys) {
453 int request_num =
454 112 (keys.Size() + MAX_PARAMETER_BATCH - 1) / MAX_PARAMETER_BATCH;
455
456
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 struct PrefetchBatch pb(request_num);
457
458
2/2
✓ Branch 1 taken 112 times.
✓ Branch 2 taken 112 times.
224 for (int start = 0, index = 0; start < keys.Size();
459 112 start += MAX_PARAMETER_BATCH, ++index) {
460 112 int key_size = std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH);
461 112 pb.key_sizes_[index] = key_size;
462 112 auto& status = pb.status_[index];
463
1/2
✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
112 if (!pb.contexts_[index]) {
464
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 pb.contexts_[index] = std::make_unique<grpc::ClientContext>();
465 }
466 112 auto& request = pb.requests_[index];
467 112 auto& response = pb.responses_[index];
468 112 request.set_keys(reinterpret_cast<const char*>(&keys[start]),
469 112 sizeof(uint64_t) * key_size);
470 // rpc
471 // grpc::ClientContext context;
472
2/4
✓ Branch 5 taken 112 times.
✗ Branch 6 not taken.
✓ Branch 8 taken 112 times.
✗ Branch 9 not taken.
224 pb.response_readers_.emplace_back(stubs_[0]->AsyncGetParameter(
473 112 pb.contexts_[index].get(), request, pb.cqs_.get()));
474 112 auto& rpc = pb.response_readers_.back();
475 // GetParameter(&context, request, &response);
476
1/2
✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
112 rpc->Finish(&response, &status, reinterpret_cast<void*>(index));
477 }
478 112 uint64_t prefetch_id = 0;
479 {
480
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 std::lock_guard<std::mutex> lk(prefetch_mu_);
481 112 prefetch_id = next_prefetch_id_++;
482
1/2
✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
112 prefetch_batches_.emplace(prefetch_id, std::move(pb));
483 112 }
484
485 112 return prefetch_id;
486 112 }
487
488 4 bool GRPCParameterClient::IsPrefetchDone(uint64_t prefetch_id) {
489
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 std::lock_guard<std::mutex> lk(prefetch_mu_);
490
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 auto it = prefetch_batches_.find(prefetch_id);
491
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
4 if (it == prefetch_batches_.end()) {
492 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
493 return false;
494 }
495 4 auto& pb = it->second;
496 4 int request_num = pb.batch_size_;
497 4 int get = 0;
498
499
1/2
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
4 if (pb.completed_count_ == pb.batch_size_) {
500 4 return true;
501 }
502
503 void* got_tag = nullptr;
504 bool ok = false;
505 auto deadline =
506 std::chrono::system_clock::now() + std::chrono::milliseconds(0);
507 for (;;) {
508 auto status = pb.cqs_->AsyncNext(&got_tag, &ok, deadline);
509 if (status == grpc::CompletionQueue::NextStatus::GOT_EVENT) {
510 if (unlikely(!ok)) {
511 LOG(ERROR) << "CompletionQueue returned not ok for prefetch";
512 }
513 pb.completed_count_++;
514 if (pb.completed_count_ == pb.batch_size_)
515 break;
516 deadline =
517 std::chrono::system_clock::now() + std::chrono::milliseconds(0);
518 continue;
519 } else if (status == grpc::CompletionQueue::NextStatus::TIMEOUT) {
520 break;
521 } else {
522 LOG(ERROR) << "CompletionQueue shutdown during prefetch";
523 break;
524 }
525 }
526 return (pb.completed_count_ == pb.batch_size_);
527 4 }
528
529 212 void GRPCParameterClient::WaitForPrefetch(uint64_t prefetch_id) {
530
1/2
✓ Branch 1 taken 212 times.
✗ Branch 2 not taken.
212 std::lock_guard<std::mutex> lk(prefetch_mu_);
531
1/2
✓ Branch 1 taken 212 times.
✗ Branch 2 not taken.
212 auto it = prefetch_batches_.find(prefetch_id);
532
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 212 times.
212 if (it == prefetch_batches_.end()) {
533 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
534 return;
535 }
536 212 auto& pb = it->second;
537 212 void* got_tag = nullptr;
538 212 bool ok = false;
539 212 int idle_rounds = 0;
540 212 constexpr auto kPollInterval = std::chrono::milliseconds(200);
541 212 constexpr int kMaxIdleRounds = 150; // 30s
542
2/2
✓ Branch 0 taken 112 times.
✓ Branch 1 taken 212 times.
324 while (pb.completed_count_ < pb.batch_size_) {
543
1/2
✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
112 auto deadline = std::chrono::system_clock::now() + kPollInterval;
544
1/2
✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
112 auto status = pb.cqs_->AsyncNext(&got_tag, &ok, deadline);
545
1/2
✓ Branch 0 taken 112 times.
✗ Branch 1 not taken.
112 if (status == grpc::CompletionQueue::NextStatus::GOT_EVENT) {
546 112 idle_rounds = 0;
547
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 112 times.
112 if (unlikely(!ok)) {
548 LOG(ERROR) << "CompletionQueue returned not ok for prefetch";
549 }
550 112 pb.completed_count_++;
551 112 continue;
552 }
553 if (status == grpc::CompletionQueue::NextStatus::TIMEOUT) {
554 idle_rounds++;
555 if (idle_rounds >= kMaxIdleRounds) {
556 LOG(ERROR) << "WaitForPrefetch timed out for prefetch_id "
557 << prefetch_id << ", completed " << pb.completed_count_
558 << "/" << pb.batch_size_;
559 break;
560 }
561 continue;
562 }
563 if (status == grpc::CompletionQueue::NextStatus::SHUTDOWN) {
564 LOG(ERROR) << "CompletionQueue shutdown while waiting prefetch";
565 break;
566 }
567 }
568
1/2
✓ Branch 1 taken 212 times.
✗ Branch 2 not taken.
212 }
569
570 112 bool GRPCParameterClient::GetPrefetchResult(
571 uint64_t prefetch_id, std::vector<std::vector<float>>* values) {
572
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 std::lock_guard<std::mutex> lk(prefetch_mu_);
573
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 auto it = prefetch_batches_.find(prefetch_id);
574
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 112 times.
112 if (it == prefetch_batches_.end()) {
575 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
576 return false;
577 }
578 112 auto& pb = it->second;
579 112 int request_num = pb.batch_size_;
580
581 112 values->clear();
582 112 int keys_size = 0;
583
2/2
✓ Branch 4 taken 112 times.
✓ Branch 5 taken 112 times.
224 for (const auto& size : pb.key_sizes_) {
584 112 keys_size += size;
585 }
586
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 values->reserve(keys_size);
587
588
2/2
✓ Branch 0 taken 112 times.
✓ Branch 1 taken 112 times.
224 for (int i = 0; i < request_num; ++i) {
589 112 auto& response = pb.responses_[i];
590 112 int key_size = pb.key_sizes_[i];
591 auto parameters = reinterpret_cast<const ParameterCompressReader*>(
592
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 response.parameter_value().data());
593
594
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 112 times.
112 if (unlikely(parameters->size != key_size)) {
595 LOG(ERROR) << "GetParameter error: " << parameters->size << " vs "
596 << key_size;
597 return false;
598 }
599
600
2/2
✓ Branch 1 taken 520 times.
✓ Branch 2 taken 112 times.
632 for (int index = 0; index < parameters->item_size(); ++index) {
601
1/2
✓ Branch 1 taken 520 times.
✗ Branch 2 not taken.
520 auto item = parameters->item(index);
602
1/2
✓ Branch 0 taken 520 times.
✗ Branch 1 not taken.
520 if (item->dim != 0) {
603
1/2
✓ Branch 1 taken 520 times.
✗ Branch 2 not taken.
520 values->emplace_back(
604
1/2
✓ Branch 2 taken 520 times.
✗ Branch 3 not taken.
1040 std::vector<float>(item->embedding, item->embedding + item->dim));
605 } else {
606 values->emplace_back(std::vector<float>(0));
607 }
608 }
609 }
610
611 112 return true;
612 112 }
613
614 bool GRPCParameterClient::GetPrefetchResultFlat(
615 uint64_t prefetch_id,
616 std::vector<float>* values,
617 int64_t* num_rows,
618 int64_t embedding_dim) {
619 std::lock_guard<std::mutex> lk(prefetch_mu_);
620 auto it = prefetch_batches_.find(prefetch_id);
621 if (it == prefetch_batches_.end()) {
622 LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id;
623 return false;
624 }
625 if (values == nullptr || num_rows == nullptr) {
626 LOG(ERROR) << "GetPrefetchResultFlat output pointer is null";
627 return false;
628 }
629
630 auto& pb = it->second;
631 int request_num = pb.batch_size_;
632 int total_keys = 0;
633 for (const auto& size : pb.key_sizes_) {
634 total_keys += size;
635 }
636
637 *num_rows = static_cast<int64_t>(total_keys);
638 values->assign(
639 static_cast<size_t>(*num_rows) * static_cast<size_t>(embedding_dim),
640 0.0f);
641
642 size_t row_offset = 0;
643 for (int i = 0; i < request_num; ++i) {
644 auto& response = pb.responses_[i];
645 int key_size = pb.key_sizes_[i];
646 auto parameters = reinterpret_cast<const ParameterCompressReader*>(
647 response.parameter_value().data());
648
649 if (unlikely(parameters->size != key_size)) {
650 LOG(ERROR) << "GetParameter error: " << parameters->size << " vs "
651 << key_size;
652 return false;
653 }
654
655 for (int index = 0; index < parameters->item_size();
656 ++index, ++row_offset) {
657 auto item = parameters->item(index);
658 if (item->dim != 0) {
659 const int64_t copy_d =
660 std::min<int64_t>(embedding_dim, static_cast<int64_t>(item->dim));
661 std::memcpy(values->data() + row_offset * embedding_dim,
662 item->embedding,
663 static_cast<size_t>(copy_d) * sizeof(float));
664 }
665 }
666 }
667
668 return true;
669 }
670
671 34 bool GRPCParameterClient::ClearPS() {
672
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 CommandRequest request;
673
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 CommandResponse response;
674
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 request.set_command(PSCommand::CLEAR_PS);
675
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 grpc::ClientContext context;
676
1/2
✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
34 SetRpcDeadline(&context);
677
1/2
✓ Branch 3 taken 34 times.
✗ Branch 4 not taken.
34 grpc::Status status = stubs_[0]->Command(&context, request, &response);
678
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 34 times.
34 if (!status.ok()) {
679 LOG(ERROR) << "gRPC ClearPS failed: " << status.error_code() << " "
680 << status.error_message();
681 }
682 68 return status.ok();
683 34 }
684
685 // Read n bytes from the server. The server does not access storage;
686 // it generates data randomly instead.
687 6 bool GRPCParameterClient::LoadFakeData(int64_t n) {
688
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 CommandRequest request;
689
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 CommandResponse response;
690
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 request.set_command(PSCommand::LOAD_FAKE_DATA);
691
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 request.add_arg1(&n, sizeof(int64_t));
692
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 grpc::ClientContext context;
693
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 SetRpcDeadline(&context);
694
1/2
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
6 grpc::Status status = stubs_[0]->Command(&context, request, &response);
695
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
6 if (!status.ok()) {
696 LOG(ERROR) << "gRPC LoadFakeData failed: " << status.error_code() << " "
697 << status.error_message();
698 return false;
699 }
700
2/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
6 if (response.reply().size() != static_cast<size_t>(n)) {
701 LOG(ERROR) << "gRPC LoadFakeData reply size mismatch: expected " << n
702 << ", got " << response.reply().size();
703 return false;
704 }
705 6 return true;
706 6 }
707
708 // Write n bytes(random generated) into the server
709 6 bool GRPCParameterClient::DumpFakeData(int64_t n) {
710
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 CommandRequest request;
711
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 CommandResponse response;
712
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 request.set_command(PSCommand::DUMP_FAKE_DATA);
713
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 request.add_arg1(&n, sizeof(int64_t));
714
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 grpc::ClientContext context;
715
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 SetRpcDeadline(&context);
716
1/2
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
6 grpc::Status status = stubs_[0]->Command(&context, request, &response);
717
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
6 if (!status.ok()) {
718 LOG(ERROR) << "gRPC DumpFakeData failed: " << status.error_code() << " "
719 << status.error_message();
720 return false;
721 }
722
3/6
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
6 if (response.reply() != "ok") {
723 LOG(ERROR) << "gRPC DumpFakeData unexpected reply: " << response.reply();
724 return false;
725 }
726 6 return true;
727 6 }
728
729 bool GRPCParameterClient::LoadCkpt(
730 const std::vector<std::string>& model_config_path,
731 const std::vector<std::string>& emb_file_path) {
732 CommandRequest request;
733 CommandResponse response;
734 request.set_command(PSCommand::RELOAD_PS);
735
736 for (auto& each : model_config_path) {
737 request.add_arg1(each);
738 }
739 for (auto& each : emb_file_path) {
740 request.add_arg2(each);
741 }
742 grpc::ClientContext context;
743 SetRpcDeadline(&context, 30000);
744 grpc::Status status = stubs_[0]->Command(&context, request, &response);
745 return status.ok();
746 }
747
748 150 bool GRPCParameterClient::PutParameter(
749 const std::vector<uint64_t>& keys,
750 const std::vector<std::vector<float>>& values) {
751
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 150 times.
150 if (keys.size() != values.size()) {
752 LOG(ERROR) << "PutParameter keys/values size mismatch: " << keys.size()
753 << " vs " << values.size();
754 return false;
755 }
756
2/2
✓ Branch 1 taken 150 times.
✓ Branch 2 taken 150 times.
300 for (int start = 0, index = 0; start < keys.size();
757 150 start += MAX_PARAMETER_BATCH, ++index) {
758 150 int key_size = std::min((int)(keys.size() - start), MAX_PARAMETER_BATCH);
759
1/2
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
150 PutParameterRequest request;
760
1/2
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
150 PutParameterResponse response;
761
1/2
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
150 ParameterCompressor compressor;
762 150 std::vector<std::string> blocks;
763
2/2
✓ Branch 0 taken 1250 times.
✓ Branch 1 taken 150 times.
1400 for (int i = start; i < start + key_size; i++) {
764 1250 auto each_key = keys[i];
765 1250 auto& embedding = values[i];
766 1250 ParameterPack parameter_pack;
767 1250 parameter_pack.key = each_key;
768 1250 parameter_pack.dim = embedding.size();
769 1250 parameter_pack.emb_data = embedding.data();
770
1/2
✓ Branch 1 taken 1250 times.
✗ Branch 2 not taken.
1250 compressor.AddItem(parameter_pack, &blocks);
771 }
772
1/2
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
150 compressor.ToBlock(&blocks);
773
2/8
✓ Branch 4 taken 150 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 150 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
150 CHECK_EQ(blocks.size(), 1);
774
1/2
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
150 request.mutable_parameter_value()->swap(blocks[0]);
775
1/2
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
150 grpc::ClientContext context;
776
1/2
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
150 SetRpcDeadline(&context);
777
1/2
✓ Branch 3 taken 150 times.
✗ Branch 4 not taken.
150 grpc::Status status = stubs_[0]->PutParameter(&context, request, &response);
778
1/2
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
150 if (status.ok()) {
779 150 continue;
780 } else {
781 std::cout << status.error_code() << ": " << status.error_message()
782 << std::endl;
783 return false;
784 }
785
6/12
✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 150 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 150 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 150 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 150 times.
✗ Branch 14 not taken.
✓ Branch 16 taken 150 times.
✗ Branch 17 not taken.
900 }
786 150 return true;
787 }
788
789 2 int GRPCParameterClient::UpdateParameter(
790 const std::string& table_name,
791 const base::ConstArray<uint64_t>& keys,
792 const std::vector<std::vector<float>>* grads) {
793 #ifdef ENABLE_PERF_REPORT
794 auto start_time = std::chrono::high_resolution_clock::now();
795 const uint64_t trace_id = recstore::g_trace_id;
796 #endif
797
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
2 if (grads == nullptr) {
798 LOG(ERROR) << "UpdateParameter grads pointer is null";
799 return -1;
800 }
801
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
2 if (keys.Size() != grads->size()) {
802 LOG(ERROR) << "UpdateParameter keys/grads size mismatch: " << keys.Size()
803 << " vs " << grads->size();
804 return -1;
805 }
806
807
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 ParameterCompressor compressor;
808
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 2 times.
6 for (size_t i = 0; i < keys.Size(); ++i) {
809 4 ParameterPack pack;
810 4 pack.key = keys[i];
811
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 pack.dim = grads->at(i).size();
812
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 pack.emb_data = grads->at(i).data();
813
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 compressor.AddItem(pack, nullptr);
814 }
815 #ifdef ENABLE_PERF_REPORT
816 auto serialize_done_time = std::chrono::high_resolution_clock::now();
817 #endif
818
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
2 if (keys.Size() == 0) {
819 LOG(WARNING) << "UpdateParameter no gradients to send";
820 return 0;
821 }
822
823
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 UpdateParameterRequest request;
824
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 UpdateParameterResponse response;
825 request.set_table_name(table_name);
826
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
2 compressor.ToBlock(request.mutable_gradients());
827
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
2 if (request.gradients().empty()) {
828 LOG(WARNING) << "UpdateParameter no serialized gradients payload";
829 return 0;
830 }
831
832
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 grpc::ClientContext context;
833
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 SetRpcDeadline(&context);
834 #ifdef ENABLE_PERF_REPORT
835 if (trace_id != 0) {
836 context.AddMetadata("x-recstore-trace-id", std::to_string(trace_id));
837 }
838 auto rpc_start_time = std::chrono::high_resolution_clock::now();
839 #endif
840 grpc::Status status =
841
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 stubs_[0]->UpdateParameter(&context, request, &response);
842 #ifdef ENABLE_PERF_REPORT
843 auto end_time = std::chrono::high_resolution_clock::now();
844 auto serialize_duration =
845 std::chrono::duration_cast<std::chrono::microseconds>(
846 serialize_done_time - start_time)
847 .count();
848 auto rpc_duration =
849 std::chrono::duration_cast<std::chrono::microseconds>(
850 end_time - rpc_start_time)
851 .count();
852 auto total_duration =
853 std::chrono::duration_cast<std::chrono::microseconds>(
854 end_time - start_time)
855 .count();
856 std::string stage_id =
857 "grpc_client::EmbUpdate|" +
858 std::to_string(
859 trace_id == 0
860 ? static_cast<uint64_t>(
861 std::chrono::duration_cast< std::chrono::microseconds>(
862 start_time.time_since_epoch())
863 .count())
864 : trace_id);
865 report("embupdate_stages",
866 stage_id.c_str(),
867 "client_serialize_us",
868 static_cast<double>(serialize_duration));
869 report("embupdate_stages",
870 stage_id.c_str(),
871 "client_rpc_us",
872 static_cast<double>(rpc_duration));
873 report("embupdate_stages",
874 stage_id.c_str(),
875 "client_total_us",
876 static_cast<double>(total_duration));
877 report("embupdate_stages",
878 stage_id.c_str(),
879 "client_request_size",
880 static_cast<double>(keys.Size()));
881 #endif
882
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
2 if (!status.ok()) {
883 LOG(ERROR) << "UpdateParameter RPC failed: " << status.error_message();
884 return -1;
885 }
886
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 return response.success() ? 0 : -1;
887 2 }
888
889 int GRPCParameterClient::UpdateParameterFlat(
890 const std::string& table_name,
891 const base::ConstArray<uint64_t>& keys,
892 const float* grads,
893 int64_t num_rows,
894 int64_t embedding_dim) {
895 #ifdef ENABLE_PERF_REPORT
896 auto start_time = std::chrono::high_resolution_clock::now();
897 const uint64_t trace_id = recstore::g_trace_id;
898 #endif
899 if (keys.Size() == 0) {
900 return 0;
901 }
902
903 ParameterCompressor compressor;
904 if (BuildUpdateBlocksFromFlat(
905 keys, grads, num_rows, embedding_dim, &compressor) != 0) {
906 return -1;
907 }
908 #ifdef ENABLE_PERF_REPORT
909 auto serialize_done_time = std::chrono::high_resolution_clock::now();
910 #endif
911 UpdateParameterRequest request;
912 UpdateParameterResponse response;
913 request.set_table_name(table_name);
914 compressor.ToBlock(request.mutable_gradients());
915 if (request.gradients().empty()) {
916 return 0;
917 }
918
919 grpc::ClientContext context;
920 SetRpcDeadline(&context);
921 #ifdef ENABLE_PERF_REPORT
922 if (trace_id != 0) {
923 context.AddMetadata("x-recstore-trace-id", std::to_string(trace_id));
924 }
925 auto rpc_start_time = std::chrono::high_resolution_clock::now();
926 #endif
927 grpc::Status status =
928 stubs_[0]->UpdateParameter(&context, request, &response);
929 #ifdef ENABLE_PERF_REPORT
930 auto end_time = std::chrono::high_resolution_clock::now();
931 auto serialize_duration =
932 std::chrono::duration_cast<std::chrono::microseconds>(
933 serialize_done_time - start_time)
934 .count();
935 auto rpc_duration =
936 std::chrono::duration_cast<std::chrono::microseconds>(
937 end_time - rpc_start_time)
938 .count();
939 auto total_duration =
940 std::chrono::duration_cast<std::chrono::microseconds>(
941 end_time - start_time)
942 .count();
943 std::string stage_id =
944 "grpc_client::EmbUpdate|" +
945 std::to_string(
946 trace_id == 0
947 ? static_cast<uint64_t>(
948 std::chrono::duration_cast< std::chrono::microseconds>(
949 start_time.time_since_epoch())
950 .count())
951 : trace_id);
952 report("embupdate_stages",
953 stage_id.c_str(),
954 "client_serialize_us",
955 static_cast<double>(serialize_duration));
956 report("embupdate_stages",
957 stage_id.c_str(),
958 "client_rpc_us",
959 static_cast<double>(rpc_duration));
960 report("embupdate_stages",
961 stage_id.c_str(),
962 "client_total_us",
963 static_cast<double>(total_duration));
964 report("embupdate_stages",
965 stage_id.c_str(),
966 "client_request_size",
967 static_cast<double>(num_rows));
968 report("embupdate_stages",
969 stage_id.c_str(),
970 "client_embedding_dim",
971 static_cast<double>(embedding_dim));
972 #endif
973 if (!status.ok()) {
974 LOG(ERROR) << "UpdateParameterFlat RPC failed: " << status.error_message();
975 return -1;
976 }
977 return response.success() ? 0 : -1;
978 }
979
980 20 int GRPCParameterClient::InitEmbeddingTable(
981 const std::string& table_name,
982 const recstore::EmbeddingTableConfig& config) {
983
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 InitEmbeddingTableRequest request;
984
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 InitEmbeddingTableResponse response;
985 request.set_table_name(table_name);
986
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
40 request.set_config_payload(config.Serialize());
987
988
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 grpc::ClientContext context;
989 grpc::Status status =
990
1/2
✓ Branch 3 taken 20 times.
✗ Branch 4 not taken.
20 stubs_[0]->InitEmbeddingTable(&context, request, &response);
991
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
20 if (!status.ok()) {
992 LOG(ERROR) << "InitEmbeddingTable RPC failed: " << status.error_message();
993 return -1;
994 }
995
2/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 20 times.
✗ Branch 4 not taken.
20 return response.success() ? 0 : -1;
996 20 }
997
998 // BasePSClient pure virtual implementations
999 // int GRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys,
1000 // float* values) {
1001 // return GetParameter(ConstArray<uint64_t>(keys.Data(), keys.Size()), values)
1002 // ? 0 : -1;
1003 // }
1004
1005 int GRPCParameterClient::AsyncGetParameter(
1006 const base::ConstArray<uint64_t>& keys, float* values) {
1007 return GetParameter(keys, values);
1008 }
1009
1010 148 int GRPCParameterClient::PutParameter(
1011 const base::ConstArray<uint64_t>& keys,
1012 const std::vector<std::vector<float>>& values) {
1013
1/2
✓ Branch 5 taken 148 times.
✗ Branch 6 not taken.
148 std::vector<uint64_t> key_vec(keys.Data(), keys.Data() + keys.Size());
1014
1/2
✓ Branch 1 taken 148 times.
✗ Branch 2 not taken.
148 bool success = PutParameter(key_vec, values);
1015
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 148 times.
148 if (!success) {
1016 LOG(ERROR) << "PutParameter batch failed";
1017 }
1018
1/2
✓ Branch 0 taken 148 times.
✗ Branch 1 not taken.
296 return success ? 1 : 0;
1019 148 }
1020
1021 void GRPCParameterClient::Command(recstore::PSCommand command) {
1022 switch (command) {
1023 case recstore::PSCommand::CLEAR_PS:
1024 ClearPS();
1025 break;
1026 case recstore::PSCommand::RELOAD_PS:
1027
1028 LOG(WARNING) << "RELOAD_PS command requires additional parameters";
1029 break;
1030 case recstore::PSCommand::LOAD_FAKE_DATA: {
1031 int64_t fake_data = 1000;
1032 LoadFakeData(fake_data);
1033 } break;
1034 case recstore::PSCommand::DUMP_FAKE_DATA: {
1035 DumpFakeData(4096);
1036 } break;
1037 default:
1038 LOG(ERROR) << "Unknown PS command: " << static_cast<int>(command);
1039 break;
1040 }
1041 }
1042
1043 8 uint64_t GRPCParameterClient::EmbWriteAsync(
1044 const base::ConstArray<uint64_t>& keys,
1045 const std::vector<std::vector<float>>& values) {
1046 8 uint64_t prewrite_id = next_prewrite_id_++;
1047 int request_num =
1048 8 (keys.Size() + MAX_PARAMETER_BATCH - 1) / MAX_PARAMETER_BATCH;
1049
1050
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 struct PrewriteBatch pb(request_num);
1051
1052
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 8 times.
16 for (int start = 0, index = 0; start < keys.Size();
1053 8 start += MAX_PARAMETER_BATCH, ++index) {
1054 8 int key_size = std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH);
1055 8 pb.key_sizes_[index] = key_size;
1056 8 auto& status = pb.status_[index];
1057
1/2
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
8 if (!pb.contexts_[index]) {
1058
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 pb.contexts_[index] = std::make_unique<grpc::ClientContext>();
1059 }
1060 8 auto& request = pb.requests_[index];
1061 8 auto& response = pb.responses_[index];
1062
1063 // Pack key/embedding pairs
1064
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 ParameterCompressor compressor;
1065 8 std::vector<std::string> blocks;
1066
2/2
✓ Branch 0 taken 96 times.
✓ Branch 1 taken 8 times.
104 for (int i = start; i < start + key_size; i++) {
1067 96 auto each_key = keys[i];
1068 96 auto& embedding = values[i];
1069 96 ParameterPack parameter_pack;
1070 96 parameter_pack.key = each_key;
1071 96 parameter_pack.dim = embedding.size();
1072 96 parameter_pack.emb_data = embedding.data();
1073
1/2
✓ Branch 1 taken 96 times.
✗ Branch 2 not taken.
96 compressor.AddItem(parameter_pack, &blocks);
1074 }
1075
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 compressor.ToBlock(&blocks);
1076
2/8
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
8 CHECK_EQ(blocks.size(), 1);
1077
1078
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 request.mutable_parameter_value()->swap(blocks[0]);
1079
1080 // Issue async RPC
1081
2/4
✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
16 pb.response_readers_.emplace_back(stubs_[0]->AsyncPutParameter(
1082 8 pb.contexts_[index].get(), request, pb.cqs_.get()));
1083
1084 // Async call; completion via CQ tag
1085 8 auto& rpc = pb.response_readers_.back();
1086
1/2
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
8 rpc->Finish(&response, &status, reinterpret_cast<void*>(index));
1087 8 }
1088
1089 // Store batch state in prewrite_batches_
1090
1/2
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
8 prewrite_batches_.emplace(prewrite_id, std::move(pb));
1091 8 return prewrite_id;
1092 8 }
1093
1094 bool GRPCParameterClient::IsWriteDone(uint64_t write_id) {
1095 LOG(ERROR) << "IsWriteDone not implemented!";
1096 auto it = prewrite_batches_.find(write_id);
1097 if (it == prewrite_batches_.end()) {
1098 LOG(ERROR) << "Invalid prewrite_id: " << write_id;
1099 return false;
1100 }
1101 auto& pb = it->second;
1102 return (pb.completed_count_ == pb.batch_size_);
1103 }
1104
1105 8 void GRPCParameterClient::WaitForWrite(uint64_t write_id) {
1106
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 auto it = prewrite_batches_.find(write_id);
1107
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
8 if (it == prewrite_batches_.end()) {
1108 LOG(ERROR) << "Invalid prewrite_id: " << write_id;
1109 return;
1110 }
1111 8 auto& pb = it->second;
1112
2/2
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
16 while (pb.completed_count_ < pb.batch_size_) {
1113 8 void* got_tag = nullptr;
1114 8 bool ok = false;
1115
2/4
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
8 if (!pb.cqs_->Next(&got_tag, &ok)) {
1116 break;
1117 }
1118
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
8 if (unlikely(!ok)) {
1119 LOG(ERROR) << "Completion queue returned not ok for write";
1120 continue;
1121 }
1122 8 pb.completed_count_++;
1123 }
1124 }
1125
1126 // Register GRPCParameterClient with the factory
1127 using BasePSClient = recstore::BasePSClient;
1128 FACTORY_REGISTER(BasePSClient, grpc, GRPCParameterClient, json);
1129