GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 4.3% 15 / 0 / 351
Functions: 11.1% 2 / 0 / 18
Branches: 2.6% 24 / 0 / 930

ps/brpc/brpc_ps_server.cpp
Line Branch Exec Source
1 #include "brpc_ps_server.h"
2
3 #include <brpc/server.h>
4 #include <fmt/core.h>
5 #include <gflags/gflags.h>
6
7 #include <chrono>
8 #include <cerrno>
9 #include <cstdint>
10 #include <cstring>
11 #include <fstream>
12 #include <memory>
13 #include <stdexcept>
14 #include <string>
15 #include <thread>
16 #include <vector>
17
18 #include "base/array.h"
19 #include "base/base.h"
20 #include "base/factory.h"
21 #include "base/flatc.h"
22 #include "base/log.h"
23 #include "base/timer.h"
24 #include "ps/base/base_ps_server.h"
25 #include "ps/base/cache_ps_impl.h"
26 #include "ps/base/parameters.h"
27 #include "ps_brpc.pb.h"
28 #include "recstore_config.h"
29 #include "src/base/config.h"
30
31 #ifdef ENABLE_PERF_REPORT
32 # include <chrono>
33 # include <cstdlib>
34 # include "base/report/report_client.h"
35 #endif
36
37 using recstoreps_brpc::CommandRequest;
38 using recstoreps_brpc::CommandResponse;
39 using recstoreps_brpc::GetParameterRequest;
40 using recstoreps_brpc::GetParameterResponse;
41 using recstoreps_brpc::InitEmbeddingTableRequest;
42 using recstoreps_brpc::InitEmbeddingTableResponse;
43 using recstoreps_brpc::PSCommand;
44 using recstoreps_brpc::PutParameterRequest;
45 using recstoreps_brpc::PutParameterResponse;
46 using recstoreps_brpc::UpdateParameterRequest;
47 using recstoreps_brpc::UpdateParameterResponse;
48
49 DEFINE_string(brpc_config_path, "", "config file path");
50 DEFINE_int32(brpc_server_port, 15000, "bRPC server port");
51 DEFINE_int32(local_shard_id,
52 -1,
53 "Only start the specified shard in multi-shard bRPC mode; "
54 "-1 means start all configured shards");
55 DEFINE_int32(brpc_server_num_threads,
56 0,
57 "Number of threads for bRPC server, 0 means auto");
58
59 namespace recstore {
60
61 namespace {
62
63 void AppendShardSuffixIfPresent(
64 nlohmann::json& config_node, const char* key, int shard_id) {
65 if (!config_node.contains(key) || !config_node[key].is_string()) {
66 return;
67 }
68 config_node[key] =
69 config_node[key].get<std::string>() + "_" + std::to_string(shard_id);
70 }
71
72 void AppendShardSuffixToNestedFilePaths(nlohmann::json& node, int shard_id) {
73 if (node.is_object()) {
74 for (auto& item : node.items()) {
75 if (item.key() == "file_path" && item.value().is_string()) {
76 item.value() =
77 item.value().get<std::string>() + "_" + std::to_string(shard_id);
78 continue;
79 }
80 AppendShardSuffixToNestedFilePaths(item.value(), shard_id);
81 }
82 return;
83 }
84 if (node.is_array()) {
85 for (auto& item : node) {
86 AppendShardSuffixToNestedFilePaths(item, shard_id);
87 }
88 }
89 }
90
91 bool ExtractPayloadBytes(
92 const brpc::Controller* cntl,
93 const std::string& proto_bytes,
94 std::string* payload_storage,
95 const char** payload_data,
96 int* payload_size) {
97 if (!cntl->request_attachment().empty()) {
98 payload_storage->clear();
99 cntl->request_attachment().copy_to(payload_storage);
100 *payload_data = payload_storage->data();
101 *payload_size = payload_storage->size();
102 return true;
103 }
104 if (!proto_bytes.empty()) {
105 *payload_data = proto_bytes.data();
106 *payload_size = proto_bytes.size();
107 return true;
108 }
109 *payload_data = nullptr;
110 *payload_size = 0;
111 return false;
112 }
113
114 std::vector<nlohmann::json>
115 6 SelectShardConfigsInternal(const nlohmann::json& cache_ps_config,
116 const std::optional<int>& local_shard_id) {
117 6 std::vector<nlohmann::json> selected;
118
3/6
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
12 if (!cache_ps_config.contains("servers") ||
119
2/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
6 !cache_ps_config["servers"].is_array()) {
120 return selected;
121 }
122
123
6/10
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 7 not taken.
✓ Branch 9 taken 12 times.
✗ Branch 10 not taken.
✓ Branch 12 taken 18 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 12 times.
✓ Branch 15 taken 6 times.
18 for (const auto& server_config : cache_ps_config["servers"]) {
124
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 8 times.
12 if (!local_shard_id.has_value()) {
125
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 selected.push_back(server_config);
126 4 continue;
127 }
128
3/6
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
16 if (!server_config.contains("shard") ||
129
2/4
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
8 !server_config["shard"].is_number_integer()) {
130 continue;
131 }
132
4/6
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 2 times.
✓ Branch 8 taken 6 times.
8 if (server_config["shard"].get<int>() == *local_shard_id) {
133
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 selected.push_back(server_config);
134 }
135 }
136 6 return selected;
137 }
138
139 } // namespace
140
141 std::vector<nlohmann::json>
142 6 SelectBRPCShardConfigs(const nlohmann::json& cache_ps_config,
143 const std::optional<int>& local_shard_id) {
144 6 return SelectShardConfigsInternal(cache_ps_config, local_shard_id);
145 }
146
147 BRPCParameterServiceImpl::BRPCParameterServiceImpl(CachePS* cache_ps)
148 : cache_ps_(cache_ps) {
149 start_time_ = std::chrono::steady_clock::now();
150 }
151
152 void BRPCParameterServiceImpl::ResetMetrics() {
153 total_get_requests_ = 0;
154 total_put_requests_ = 0;
155 total_get_keys_ = 0;
156 total_put_keys_ = 0;
157 total_get_bytes_ = 0;
158 total_put_bytes_ = 0;
159 start_time_ = std::chrono::steady_clock::now();
160 }
161
162 void BRPCParameterServiceImpl::PrintMetrics(const std::string& table_name,
163 const std::string& unique_id) {
164 auto now = std::chrono::steady_clock::now();
165 double elapsed_s = std::chrono::duration<double>(now - start_time_).count();
166 if (elapsed_s > 0) {
167 double overall_qps =
168 (total_get_requests_ + total_put_requests_) / elapsed_s;
169 double overall_throughput_mbps =
170 ((total_get_bytes_ + total_put_bytes_) / 1024.0 / 1024.0) / elapsed_s;
171
172 // Report QPS and throughput metrics
173 // report(table_name.c_str(), unique_id.c_str(), "overall_qps",
174 // overall_qps); report(table_name.c_str(),
175 // unique_id.c_str(),
176 // "overall_throughput_mbps",
177 // overall_throughput_mbps);
178 }
179 }
180
181 void BRPCParameterServiceImpl::GetParameter(
182 google::protobuf::RpcController* controller,
183 const GetParameterRequest* request,
184 GetParameterResponse* response,
185 google::protobuf::Closure* done) {
186 brpc::ClosureGuard done_guard(done);
187 brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
188
189 #ifdef ENABLE_PERF_REPORT
190 auto start_time = std::chrono::high_resolution_clock::now();
191 uint64_t trace_id = cntl->log_id();
192 std::string unique_id = "embread_debug" + std::to_string(trace_id);
193 #endif
194 std::string keys_storage;
195 const char* keys_data = nullptr;
196 int keys_size = 0;
197 ExtractPayloadBytes(
198 cntl, request->keys(), &keys_storage, &keys_data, &keys_size);
199 base::ConstArray<uint64_t> keys_array;
200 keys_array.SetData(keys_data, keys_size);
201 if (keys_size % static_cast<int>(sizeof(uint64_t)) != 0) {
202 LOG(ERROR) << "GetParameter invalid keys payload size=" << keys_size;
203 return;
204 }
205 bool isPerf = request->has_perf() && request->perf();
206
207 if (isPerf) {
208 xmh::PerfCounter::Record("PS Get Keys", keys_array.Size());
209 }
210
211 xmh::Timer timer_ps_get_req("PS GetParameter Req");
212 ParameterCompressor compressor;
213
214 RECSTORE_LOG_EVERY_MS(INFO, 1000)
215 << "[bRPC PS] Getting " << keys_array.Size() << " keys";
216
217 int total_dim = 0;
218
219 #ifdef ENABLE_PERF_REPORT
220 auto cache_loop_start = std::chrono::high_resolution_clock::now();
221 #endif
222 std::vector<ParameterPack> packs;
223 packs.reserve(keys_array.Size());
224 cache_ps_->GetParameterRun2Completion(keys_array, packs, 0);
225
226 for (auto& pack : packs) {
227 compressor.AddItem(pack, nullptr);
228 total_dim += pack.dim;
229 }
230 #ifdef ENABLE_PERF_REPORT
231 auto cache_loop_end = std::chrono::high_resolution_clock::now();
232 auto cache_loop_duration =
233 std::chrono::duration_cast<std::chrono::microseconds>(
234 cache_loop_end - cache_loop_start)
235 .count();
236 double cache_loop_start_us =
237 std::chrono::duration_cast<std::chrono::microseconds>(
238 cache_loop_start.time_since_epoch())
239 .count();
240
241 std::string report_id_cache =
242 "brpc_server::GetParameter|" +
243 std::to_string(static_cast<uint64_t>(cache_loop_start_us));
244 report("embread_stages",
245 report_id_cache.c_str(),
246 "cache_lookup_us",
247 static_cast<double>(cache_loop_duration));
248
249 FlameGraphData cache_loop_fg = {
250 "brpc_server::CacheGet_Loop",
251 cache_loop_start_us,
252 4, // level
253 static_cast<double>(cache_loop_duration),
254 static_cast<double>(cache_loop_duration)};
255 if (trace_id != 0) {
256 std::string cache_unique_id =
257 "embread_debug|" +
258 std::to_string(static_cast<uint64_t>(cache_loop_start_us));
259 report_flame_graph(
260 "emb_read_flame_map", cache_unique_id.c_str(), cache_loop_fg);
261 }
262
263 auto toblock_start = std::chrono::high_resolution_clock::now();
264 #endif
265
266 compressor.AppendToIOBuf(&cntl->response_attachment());
267
268 #ifdef ENABLE_PERF_REPORT
269 auto toblock_end = std::chrono::high_resolution_clock::now();
270 auto toblock_duration =
271 std::chrono::duration_cast<std::chrono::microseconds>(
272 toblock_end - toblock_start)
273 .count();
274 double toblock_start_us =
275 std::chrono::duration_cast<std::chrono::microseconds>(
276 toblock_start.time_since_epoch())
277 .count();
278 FlameGraphData toblock_fg = {
279 "brpc_server::Compressor_ToBlock",
280 toblock_start_us,
281 4, // level
282 static_cast<double>(toblock_duration),
283 static_cast<double>(toblock_duration)};
284 if (trace_id != 0) {
285 std::string toblock_unique_id =
286 "embread_debug|" +
287 std::to_string(static_cast<uint64_t>(toblock_start_us));
288 report_flame_graph(
289 "emb_read_flame_map", toblock_unique_id.c_str(), toblock_fg);
290 }
291 #endif
292
293 total_get_requests_++;
294 total_get_keys_ += keys_array.Size();
295 total_get_bytes_ += total_dim * sizeof(float);
296
297 if (isPerf) {
298 timer_ps_get_req.end();
299 } else {
300 timer_ps_get_req.destroy();
301 }
302
303 #ifdef ENABLE_PERF_REPORT
304 auto end_time = std::chrono::high_resolution_clock::now();
305 auto duration =
306 std::chrono::duration_cast<std::chrono::microseconds>(
307 end_time - start_time)
308 .count();
309 report("ps_server_latency",
310 "GetParameter",
311 "latency_us",
312 static_cast<double>(duration));
313
314 double start_us_for_key =
315 std::chrono::duration_cast<std::chrono::microseconds>(
316 start_time.time_since_epoch())
317 .count();
318 std::string op_latency_key =
319 "EmbRead|" + std::to_string(static_cast<uint64_t>(start_us_for_key));
320 report("op_latency",
321 op_latency_key.c_str(),
322 "recserver_us",
323 static_cast<double>(duration));
324
325 double start_us =
326 std::chrono::duration_cast<std::chrono::microseconds>(
327 start_time.time_since_epoch())
328 .count();
329
330 std::string report_id = "brpc_server::GetParameter|" +
331 std::to_string(static_cast<uint64_t>(start_us));
332
333 report("embread_stages",
334 report_id.c_str(),
335 "duration_us",
336 static_cast<double>(duration));
337
338 report("embread_stages",
339 report_id.c_str(),
340 "request_size",
341 static_cast<double>(keys_array.Size()));
342
343 FlameGraphData fg_data = {
344 "brpc_server::GetParameter",
345 start_us,
346 3, // level
347 static_cast<double>(duration),
348 static_cast<double>(duration)};
349 if (trace_id != 0) {
350 std::string req_unique_id =
351 "embread_debug|" + std::to_string(static_cast<uint64_t>(start_us));
352 report_flame_graph("emb_read_flame_map", req_unique_id.c_str(), fg_data);
353 }
354 #endif
355 }
356
357 void BRPCParameterServiceImpl::Command(
358 google::protobuf::RpcController* controller,
359 const CommandRequest* request,
360 CommandResponse* response,
361 google::protobuf::Closure* done) {
362 brpc::ClosureGuard done_guard(done);
363 brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
364
365 if (request->command() == recstoreps_brpc::PSCommand::CLEAR_PS) {
366 LOG(WARNING) << "[PS Command] Clear All";
367 cache_ps_->Clear();
368 } else if (request->command() == recstoreps_brpc::PSCommand::RELOAD_PS) {
369 LOG(WARNING) << "[PS Command] Reload PS";
370 CHECK_NE(request->arg1().size(), 0);
371 CHECK_NE(request->arg2().size(), 0);
372 CHECK_EQ(request->arg1().size(), 1);
373 LOG(WARNING) << "model_config_path = " << request->arg1()[0];
374 for (int i = 0; i < request->arg2().size(); i++) {
375 LOG(WARNING) << fmt::format("emb_file {}: {}", i, request->arg2()[i]);
376 }
377 std::vector<std::string> arg1;
378 for (auto& each : request->arg1()) {
379 arg1.push_back(each);
380 }
381 std::vector<std::string> arg2;
382 for (auto& each : request->arg2()) {
383 arg2.push_back(each);
384 }
385 cache_ps_->Initialize(arg1, arg2);
386 } else if (request->command() == recstoreps_brpc::PSCommand::LOAD_FAKE_DATA) {
387 if (request->arg1_size() != 1 ||
388 static_cast<size_t>(request->arg1(0).size()) != sizeof(int64_t)) {
389 LOG(ERROR) << "LOAD_FAKE_DATA: arg1 must be one " << sizeof(int64_t)
390 << "-byte int64_t (requested reply payload size)";
391 cntl->SetFailed(EINVAL, "LOAD_FAKE_DATA invalid arg1 size");
392 return;
393 }
394 int64_t payload_bytes = 0;
395 std::memcpy(&payload_bytes, request->arg1(0).data(), sizeof(int64_t));
396 if (payload_bytes < 0) {
397 LOG(ERROR) << "LOAD_FAKE_DATA: payload_bytes must be non-negative, got "
398 << payload_bytes;
399 cntl->SetFailed(
400 EINVAL, "LOAD_FAKE_DATA payload_bytes must be non-negative");
401 return;
402 }
403 constexpr int64_t kMaxReplyPayload = 16 * 1024 * 1024;
404 if (payload_bytes > kMaxReplyPayload) {
405 LOG(ERROR) << "LOAD_FAKE_DATA: payload_bytes " << payload_bytes
406 << " exceeds cap " << kMaxReplyPayload;
407 cntl->SetFailed(EINVAL, "LOAD_FAKE_DATA payload too large");
408 return;
409 }
410 std::string fake(static_cast<size_t>(payload_bytes), '\xab');
411 response->set_reply(std::move(fake));
412 } else if (request->command() == recstoreps_brpc::PSCommand::DUMP_FAKE_DATA) {
413 if (request->arg1_size() != 1 ||
414 static_cast<size_t>(request->arg1(0).size()) != sizeof(int64_t)) {
415 LOG(ERROR) << "DUMP_FAKE_DATA: arg1 must be one " << sizeof(int64_t)
416 << "-byte int64_t (payload bytes n)";
417 cntl->SetFailed(EINVAL, "DUMP_FAKE_DATA invalid arg1 size");
418 return;
419 }
420 int64_t n = 0;
421 std::memcpy(&n, request->arg1(0).data(), sizeof(int64_t));
422 if (n <= 0) {
423 LOG(ERROR) << "DUMP_FAKE_DATA: n must be positive";
424 cntl->SetFailed(EINVAL, "DUMP_FAKE_DATA n must be positive");
425 return;
426 }
427 if (n % static_cast<int64_t>(sizeof(float)) != 0) {
428 LOG(ERROR) << "DUMP_FAKE_DATA: n must be a multiple of " << sizeof(float);
429 cntl->SetFailed(
430 EINVAL, "DUMP_FAKE_DATA n must be multiple of sizeof(float)");
431 return;
432 }
433 constexpr int64_t kMaxDumpBytes = 64 * 1024 * 1024;
434 if (n > kMaxDumpBytes) {
435 LOG(ERROR) << "DUMP_FAKE_DATA: n exceeds cap " << kMaxDumpBytes;
436 cntl->SetFailed(EINVAL, "DUMP_FAKE_DATA n exceeds cap");
437 return;
438 }
439 response->set_reply("ok");
440 } else {
441 LOG(FATAL) << "invalid command";
442 }
443 }
444
445 void BRPCParameterServiceImpl::PutParameter(
446 google::protobuf::RpcController* controller,
447 const PutParameterRequest* request,
448 PutParameterResponse* response,
449 google::protobuf::Closure* done) {
450 brpc::ClosureGuard done_guard(done);
451
452 brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
453
454 std::string payload_storage;
455 const char* payload_data = nullptr;
456 int payload_size = 0;
457 if (!ExtractPayloadBytes(
458 cntl,
459 request->parameter_value(),
460 &payload_storage,
461 &payload_data,
462 &payload_size)) {
463 LOG(ERROR) << "PutParameter empty payload";
464 return;
465 }
466
467 #ifdef ENABLE_PERF_REPORT
468 auto start_time = std::chrono::high_resolution_clock::now();
469 #endif
470
471 const ParameterCompressReader* reader =
472 reinterpret_cast<const ParameterCompressReader*>(payload_data);
473 if (!reader->Valid(payload_size)) {
474 LOG(ERROR) << "PutParameter invalid payload, size=" << payload_size;
475 return;
476 }
477 int size = reader->item_size();
478 uint64_t total_bytes = 0;
479
480 for (int i = 0; i < size; i++) {
481 cache_ps_->PutSingleParameter(reader->item(i), 0);
482 total_bytes += reader->item(i)->dim * sizeof(float);
483 }
484
485 total_put_requests_++;
486 total_put_keys_ += size;
487 total_put_bytes_ += total_bytes;
488
489 #ifdef ENABLE_PERF_REPORT
490 auto end_time = std::chrono::high_resolution_clock::now();
491 auto duration =
492 std::chrono::duration_cast<std::chrono::microseconds>(
493 end_time - start_time)
494 .count();
495 report("ps_server_latency",
496 "PutParameter",
497 "latency_us",
498 static_cast<double>(duration));
499
500 double start_us_for_key =
501 std::chrono::duration_cast<std::chrono::microseconds>(
502 start_time.time_since_epoch())
503 .count();
504 std::string op_latency_key =
505 "EmbWrite|" + std::to_string(static_cast<uint64_t>(start_us_for_key));
506 report("op_latency",
507 op_latency_key.c_str(),
508 "recserver_us",
509 static_cast<double>(duration));
510 #endif
511 }
512
513 void BRPCParameterServiceImpl::UpdateParameter(
514 google::protobuf::RpcController* controller,
515 const UpdateParameterRequest* request,
516 UpdateParameterResponse* reply,
517 google::protobuf::Closure* done) {
518 brpc::ClosureGuard done_guard(done);
519
520 brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
521
522 #ifdef ENABLE_PERF_REPORT
523 auto start_time = std::chrono::high_resolution_clock::now();
524 uint64_t trace_id = 0;
525 const std::string* header_trace =
526 cntl->http_request().GetHeader("x-recstore-trace-id");
527 if (header_trace != nullptr && !header_trace->empty()) {
528 trace_id = static_cast<uint64_t>(
529 std::strtoull(header_trace->c_str(), nullptr, 10));
530 }
531 #endif
532 bool success = false;
533 int size = 0;
534 #ifdef ENABLE_PERF_REPORT
535 auto before_cache_update_time = std::chrono::high_resolution_clock::now();
536 #endif
537
538 try {
539 const std::string& table_name = request->table_name();
540
541 std::string payload_storage;
542 const char* payload_data = nullptr;
543 int payload_size = 0;
544 if (!ExtractPayloadBytes(
545 cntl,
546 request->gradients(),
547 &payload_storage,
548 &payload_data,
549 &payload_size)) {
550 throw std::runtime_error("UpdateParameter empty gradients payload");
551 }
552
553 const ParameterCompressReader* reader =
554 reinterpret_cast<const ParameterCompressReader*>(payload_data);
555 if (!reader->Valid(payload_size)) {
556 throw std::runtime_error("UpdateParameter invalid gradients payload");
557 }
558 size = reader->item_size();
559
560 #ifdef ENABLE_PERF_REPORT
561 before_cache_update_time = std::chrono::high_resolution_clock::now();
562 #endif
563 success = cache_ps_->UpdateParameter(table_name, reader, 0);
564
565 RECSTORE_LOG_EVERY_MS(INFO, 2000)
566 << "UpdateParameter: table=" << table_name << ", keys=" << size;
567
568 reply->set_success(success);
569 } catch (const std::exception& e) {
570 LOG(ERROR) << "UpdateParameter error: " << e.what();
571 reply->set_success(false);
572 }
573
574 #ifdef ENABLE_PERF_REPORT
575 auto end_time = std::chrono::high_resolution_clock::now();
576 auto duration =
577 std::chrono::duration_cast<std::chrono::microseconds>(
578 end_time - start_time)
579 .count();
580 report("ps_server_latency",
581 "UpdateParameter",
582 "latency_us",
583 static_cast<double>(duration));
584
585 double start_us_for_key =
586 std::chrono::duration_cast<std::chrono::microseconds>(
587 start_time.time_since_epoch())
588 .count();
589 std::string op_latency_key =
590 "EmbUpdate|" + std::to_string(static_cast<uint64_t>(start_us_for_key));
591 report("op_latency",
592 op_latency_key.c_str(),
593 "recserver_us",
594 static_cast<double>(duration));
595
596 auto backend_update_duration =
597 std::chrono::duration_cast<std::chrono::microseconds>(
598 end_time - before_cache_update_time)
599 .count();
600 const uint64_t effective_trace_id =
601 trace_id == 0 ? static_cast<uint64_t>(start_us_for_key) : trace_id;
602 std::string update_stage_id =
603 "brpc_server::EmbUpdate|" + std::to_string(effective_trace_id);
604 report("embupdate_stages",
605 update_stage_id.c_str(),
606 "server_total_us",
607 static_cast<double>(duration));
608 report("embupdate_stages",
609 update_stage_id.c_str(),
610 "server_backend_update_us",
611 static_cast<double>(backend_update_duration));
612 report("embupdate_stages",
613 update_stage_id.c_str(),
614 "server_request_size",
615 static_cast<double>(size));
616 report("embupdate_stages",
617 update_stage_id.c_str(),
618 "server_success",
619 success ? 1.0 : 0.0);
620 #endif
621 }
622
623 void BRPCParameterServiceImpl::InitEmbeddingTable(
624 google::protobuf::RpcController* controller,
625 const InitEmbeddingTableRequest* request,
626 InitEmbeddingTableResponse* reply,
627 google::protobuf::Closure* done) {
628 brpc::ClosureGuard done_guard(done);
629
630 #ifdef ENABLE_PERF_REPORT
631 auto start_time = std::chrono::high_resolution_clock::now();
632 #endif
633
634 try {
635 if (request->has_config_payload()) {
636 auto payload = request->config_payload();
637 nlohmann::json cfg = nlohmann::json::parse(payload);
638 uint64_t num_embeddings = cfg.value("num_embeddings", 0);
639 uint64_t embedding_dim = cfg.value("embedding_dim", 0);
640 RECSTORE_LOG_EVERY_MS(INFO, 2000)
641 << "InitEmbeddingTable: table=" << request->table_name()
642 << ", num_embeddings=" << num_embeddings
643 << ", embedding_dim=" << embedding_dim;
644
645 bool init_success = cache_ps_->InitTable(
646 request->table_name(), num_embeddings, embedding_dim);
647 reply->set_success(init_success);
648 } else {
649 LOG(WARNING) << "InitEmbeddingTable called without config_payload";
650 reply->set_success(false);
651 }
652 } catch (const std::exception& e) {
653 LOG(ERROR) << "InitEmbeddingTable error: " << e.what();
654 reply->set_success(false);
655 }
656
657 #ifdef ENABLE_PERF_REPORT
658 auto end_time = std::chrono::high_resolution_clock::now();
659 auto duration =
660 std::chrono::duration_cast<std::chrono::microseconds>(
661 end_time - start_time)
662 .count();
663 report("ps_server_latency",
664 "InitEmbeddingTable",
665 "latency_us",
666 static_cast<double>(duration));
667
668 double start_us_for_key =
669 std::chrono::duration_cast<std::chrono::microseconds>(
670 start_time.time_since_epoch())
671 .count();
672 std::string op_latency_key =
673 "InitEmbeddingTable|" +
674 std::to_string(static_cast<uint64_t>(start_us_for_key));
675 report("op_latency",
676 op_latency_key.c_str(),
677 "recserver_us",
678 static_cast<double>(duration));
679 #endif
680 }
681
682 class BRPCParameterServer : public BaseParameterServer {
683 public:
684 BRPCParameterServer() = default;
685
686 void Run() {
687 // Check whether multi-shard mode is configured
688 int num_shards = 1; // default: single shard
689 if (config_["cache_ps"].contains("num_shards")) {
690 num_shards = config_["cache_ps"]["num_shards"];
691 }
692 const std::optional<int> local_shard_id =
693 FLAGS_local_shard_id >= 0
694 ? std::make_optional(FLAGS_local_shard_id)
695 : std::nullopt;
696
697 if (num_shards > 1) {
698 // Multi-server startup
699 std::cout
700 << "Starting distributed parameter server (bRPC), number of shards: "
701 << num_shards << std::endl;
702
703 if (!config_["cache_ps"].contains("servers")) {
704 LOG(FATAL) << "num_shards > 1 but cache_ps.servers is missing";
705 return;
706 }
707
708 const auto& cache_ps_config = config_["cache_ps"];
709 auto servers =
710 SelectShardConfigsInternal(cache_ps_config, local_shard_id);
711 const auto configured_servers = cache_ps_config["servers"];
712 if (configured_servers.size() != num_shards) {
713 LOG(FATAL) << "servers 配置数量 (" << configured_servers.size()
714 << ") 与 num_shards (" << num_shards << ") 不匹配";
715 return;
716 }
717 if (local_shard_id.has_value() && servers.empty()) {
718 LOG(FATAL) << "local_shard_id=" << *local_shard_id
719 << " is not present in cache_ps.servers";
720 return;
721 }
722 if (!local_shard_id.has_value() &&
723 servers.size() != configured_servers.size()) {
724 LOG(FATAL) << "Selected shard count (" << servers.size()
725 << ") does not match configured server count ("
726 << configured_servers.size() << ")";
727 return;
728 }
729
730 std::vector<std::thread> server_threads;
731
732 for (auto& server_config : servers) {
733 server_threads.emplace_back([this, server_config]() {
734 std::string host = server_config["host"];
735 int port = server_config["port"];
736 int shard = server_config["shard"];
737
738 std::string server_address = host + ":" + std::to_string(port);
739
740 nlohmann::json shard_config = config_["cache_ps"];
741 shard_config["num_shards"] = 1;
742 shard_config["servers"] = nlohmann::json::array({server_config});
743 if (shard_config.contains("base_kv_config") &&
744 shard_config["base_kv_config"].is_object()) {
745 auto& base_kv_config = shard_config["base_kv_config"];
746 AppendShardSuffixIfPresent(base_kv_config, "path", shard);
747 AppendShardSuffixIfPresent(base_kv_config, "rocksdb_path", shard);
748 AppendShardSuffixToNestedFilePaths(base_kv_config, shard);
749 LOG(INFO) << "bRPC shard " << shard
750 << " using base_kv_config: " << base_kv_config.dump();
751 }
752
753 auto cache_ps = std::make_unique<CachePS>(shard_config);
754 auto service =
755 std::make_unique<BRPCParameterServiceImpl>(cache_ps.get());
756
757 brpc::Server server;
758 brpc::ServerOptions options;
759 options.num_threads = FLAGS_brpc_server_num_threads;
760
761 if (server.AddService(
762 service.get(), brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
763 LOG(ERROR) << "Failed to add service!";
764 return;
765 }
766
767 if (server.Start(server_address.c_str(), &options) != 0) {
768 LOG(ERROR) << "Failed to start bRPC server at " << server_address;
769 return;
770 }
771
772 std::cout << "bRPC Server shard " << shard << " listening on "
773 << server_address << std::endl;
774 server.RunUntilAskedToQuit();
775 });
776 }
777
778 // Wait for all server threads
779 for (auto& t : server_threads) {
780 t.join();
781 }
782 } else {
783 // Single-server startup
784 std::cout << "Starting single parameter server (bRPC)" << std::endl;
785 std::string server_address =
786 "0.0.0.0:" + std::to_string(FLAGS_brpc_server_port);
787 auto cache_ps = std::make_unique<CachePS>(config_["cache_ps"]);
788 auto service = std::make_unique<BRPCParameterServiceImpl>(cache_ps.get());
789
790 std::atomic<bool> metrics_running{true};
791 std::thread metrics_thread([&service, &metrics_running]() {
792 while (metrics_running) {
793 std::this_thread::sleep_for(std::chrono::seconds(10));
794 service->PrintMetrics();
795 service->ResetMetrics();
796 }
797 });
798
799 brpc::Server server;
800 brpc::ServerOptions options;
801 options.num_threads = FLAGS_brpc_server_num_threads;
802
803 if (server.AddService(service.get(), brpc::SERVER_DOESNT_OWN_SERVICE) !=
804 0) {
805 LOG(ERROR) << "Failed to add service!";
806 metrics_running = false;
807 if (metrics_thread.joinable()) {
808 metrics_thread.join();
809 }
810 return;
811 }
812
813 if (server.Start(server_address.c_str(), &options) != 0) {
814 LOG(ERROR) << "Failed to start bRPC server at " << server_address;
815 metrics_running = false;
816 if (metrics_thread.joinable()) {
817 metrics_thread.join();
818 }
819 return;
820 }
821
822 std::cout << "bRPC Server listening on " << server_address << std::endl;
823 server.RunUntilAskedToQuit();
824
825 metrics_running = false;
826 if (metrics_thread.joinable()) {
827 metrics_thread.join();
828 }
829 }
830 }
831 };
832
833 FACTORY_REGISTER(BaseParameterServer, BRPCParameterServer, BRPCParameterServer);
834
835 } // namespace recstore
836
837 #ifndef RECSTORE_NO_SERVER_MAIN
838 int main(int argc, char** argv) {
839 gflags::ParseCommandLineFlags(&argc, &argv, true);
840
841 const std::string config_path =
842 FLAGS_brpc_config_path.empty()
843 ? base::ResolveRecStoreConfigPath().string()
844 : FLAGS_brpc_config_path;
845 std::ifstream config_file(config_path);
846 if (!config_file.is_open()) {
847 throw std::runtime_error("Cannot open config file: " + config_path);
848 }
849 nlohmann::json ex;
850 config_file >> ex;
851
852 recstore::BRPCParameterServer ps;
853 std::cout << "bRPC Parameter server config: " << ex.dump(2) << std::endl;
854 ps.Init(ex);
855 ps.Run();
856
857 return 0;
858 }
859 #endif
860