GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 0.0% 0 / 0 / 314
Functions: 0.0% 0 / 0 / 16
Branches: 0.0% 0 / 0 / 906

ps/grpc/grpc_ps_server.cpp
Line Branch Exec Source
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>

#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "base/array.h"
#include "base/base.h"
#include "base/flatc.h"
#include "base/init.h"
#include "base/timer.h"
#include "ps.grpc.pb.h"
#include "ps.pb.h"
#include "ps/base/base_ps_server.h"
#include "ps/base/cache_ps_impl.h"
#include "ps/base/parameters.h"
#include "recstore_config.h"
#include "src/base/config.h"

#ifdef ENABLE_PERF_REPORT
# include <chrono>
# include <cstdlib>

# include "base/report/report_client.h"
#else
# include "../report_client.h"
#endif
36
37 using grpc::Server;
38 using grpc::ServerBuilder;
39 using grpc::ServerContext;
40 using grpc::Status;
41
42 using recstoreps::CommandRequest;
43 using recstoreps::CommandResponse;
44 using recstoreps::GetParameterRequest;
45 using recstoreps::GetParameterResponse;
46 using recstoreps::InitEmbeddingTableRequest;
47 using recstoreps::InitEmbeddingTableResponse;
48 using recstoreps::PSCommand;
49 using recstoreps::PutParameterRequest;
50 using recstoreps::PutParameterResponse;
51 using recstoreps::UpdateParameterRequest;
52 using recstoreps::UpdateParameterResponse;
53
54 DEFINE_string(config_path, "", "config file path");
55 DEFINE_int32(grpc_local_shard_id,
56 -1,
57 "Only start the specified shard in multi-shard gRPC mode; "
58 "-1 means start all configured shards");
59
60 namespace {
61
62 void AppendShardSuffixIfPresent(
63 nlohmann::json& config_node, const char* key, int shard_id) {
64 if (!config_node.contains(key) || !config_node[key].is_string()) {
65 return;
66 }
67 config_node[key] =
68 config_node[key].get<std::string>() + "_" + std::to_string(shard_id);
69 }
70
71 void AppendShardSuffixToNestedFilePaths(nlohmann::json& node, int shard_id) {
72 if (node.is_object()) {
73 for (auto& item : node.items()) {
74 if (item.key() == "file_path" && item.value().is_string()) {
75 item.value() =
76 item.value().get<std::string>() + "_" + std::to_string(shard_id);
77 continue;
78 }
79 AppendShardSuffixToNestedFilePaths(item.value(), shard_id);
80 }
81 return;
82 }
83 if (node.is_array()) {
84 for (auto& item : node) {
85 AppendShardSuffixToNestedFilePaths(item, shard_id);
86 }
87 }
88 }
89
90 std::vector<nlohmann::json>
91 SelectGRPCShardConfigs(const nlohmann::json& cache_ps_config,
92 const std::optional<int>& local_shard_id) {
93 std::vector<nlohmann::json> selected;
94 if (!cache_ps_config.contains("servers") ||
95 !cache_ps_config["servers"].is_array()) {
96 return selected;
97 }
98
99 for (const auto& server_config : cache_ps_config["servers"]) {
100 if (!local_shard_id.has_value()) {
101 selected.push_back(server_config);
102 continue;
103 }
104 if (!server_config.contains("shard") ||
105 !server_config["shard"].is_number_integer()) {
106 continue;
107 }
108 if (server_config["shard"].get<int>() == *local_shard_id) {
109 selected.push_back(server_config);
110 }
111 }
112 return selected;
113 }
114
115 } // namespace
116
117 class ParameterServiceImpl final
118 : public recstoreps::ParameterService::Service {
119 public:
120 ParameterServiceImpl(CachePS* cache_ps) {
121 cache_ps_ = cache_ps;
122 start_time_ = std::chrono::steady_clock::now();
123 }
124 void ResetMetrics() {
125 total_get_requests_ = 0;
126 total_put_requests_ = 0;
127 total_get_keys_ = 0;
128 total_put_keys_ = 0;
129 total_get_bytes_ = 0;
130 total_put_bytes_ = 0;
131 start_time_ = std::chrono::steady_clock::now();
132 }
133 void PrintMetrics(const std::string& table_name = "grpc_ps_server_metrics",
134 const std::string& unique_id = "default_server") {
135 auto now = std::chrono::steady_clock::now();
136 double elapsed_s = std::chrono::duration<double>(now - start_time_).count();
137 if (elapsed_s > 0) {
138 double overall_qps =
139 (total_get_requests_ + total_put_requests_) / elapsed_s;
140 double overall_throughput_mbps =
141 ((total_get_bytes_ + total_put_bytes_) / 1024.0 / 1024.0) / elapsed_s;
142
143 // Report QPS and throughput metrics
144
145 // report(table_name.c_str(), unique_id.c_str(), "overall_qps",
146 // overall_qps); report(table_name.c_str(),
147 // unique_id.c_str(),
148 // "overall_throughput_mbps",
149 // overall_throughput_mbps);
150 }
151 }
152
153 private:
154 Status GetParameter(ServerContext* context,
155 const GetParameterRequest* request,
156 GetParameterResponse* reply) override {
157 #ifdef ENABLE_PERF_REPORT
158 auto start_time = std::chrono::high_resolution_clock::now();
159 #endif
160 base::ConstArray<uint64_t> keys_array(request->keys());
161 bool isPerf = request->has_perf() && request->perf();
162 if (isPerf) {
163 xmh::PerfCounter::Record("PS Get Keys", keys_array.Size());
164 }
165 xmh::Timer timer_ps_get_req("PS GetParameter Req");
166 ParameterCompressor compressor(std::numeric_limits<int>::max());
167 std::vector<std::string> blocks;
168 RECSTORE_LOG_EVERY_MS(INFO, 1000)
169 << "[PS] Getting " << keys_array.Size() << " keys";
170 int total_dim = 0;
171 #ifdef ENABLE_PERF_REPORT
172 auto cache_start_time = std::chrono::high_resolution_clock::now();
173 #endif
174 std::vector<ParameterPack> packs;
175 packs.reserve(keys_array.Size());
176 cache_ps_->GetParameterRun2Completion(keys_array, packs, 0);
177
178 for (auto& pack : packs) {
179 compressor.AddItem(pack, &blocks);
180 total_dim += pack.dim;
181 }
182 #ifdef ENABLE_PERF_REPORT
183 auto cache_end_time = std::chrono::high_resolution_clock::now();
184 auto cache_duration =
185 std::chrono::duration_cast<std::chrono::microseconds>(
186 cache_end_time - cache_start_time)
187 .count();
188 double start_us_for_cache =
189 std::chrono::duration_cast<std::chrono::microseconds>(
190 start_time.time_since_epoch())
191 .count();
192 std::string report_id_for_cache =
193 "grpc_server::GetParameter|" +
194 std::to_string(static_cast<uint64_t>(start_us_for_cache));
195 report("embread_stages",
196 report_id_for_cache.c_str(),
197 "cache_lookup_us",
198 static_cast<double>(cache_duration));
199 #endif
200
201 compressor.ToBlock(&blocks);
202 CHECK_EQ(blocks.size(), 1);
203 reply->mutable_parameter_value()->swap(blocks[0]);
204 total_get_requests_++;
205 total_get_keys_ += keys_array.Size();
206 total_get_bytes_ += total_dim * sizeof(float);
207
208 if (isPerf) {
209 timer_ps_get_req.end();
210 } else {
211 timer_ps_get_req.destroy();
212 }
213
214 #ifdef ENABLE_PERF_REPORT
215 auto end_time = std::chrono::high_resolution_clock::now();
216 double start_us =
217 std::chrono::duration_cast<std::chrono::microseconds>(
218 start_time.time_since_epoch())
219 .count();
220 auto duration =
221 std::chrono::duration_cast<std::chrono::microseconds>(
222 end_time - start_time)
223 .count();
224
225 std::string report_id = "grpc_server::GetParameter|" +
226 std::to_string(static_cast<uint64_t>(start_us));
227
228 std::string op_latency_key =
229 "EmbRead|" + std::to_string(static_cast<uint64_t>(start_us));
230 report("op_latency",
231 op_latency_key.c_str(),
232 "recserver_us",
233 static_cast<double>(duration));
234
235 report("embread_stages",
236 report_id.c_str(),
237 "duration_us",
238 static_cast<double>(duration));
239
240 report("embread_stages",
241 report_id.c_str(),
242 "request_size",
243 static_cast<double>(keys_array.Size()));
244
245 std::string unique_id =
246 "embread_debug|" + std::to_string(static_cast<uint64_t>(start_us));
247 FlameGraphData grpc_server_data = {
248 "grpc_ps_server::GetParameter",
249 start_us,
250 2, // level
251 static_cast<double>(duration),
252 static_cast<double>(duration)};
253 report_flame_graph(
254 "emb_read_flame_map", unique_id.c_str(), grpc_server_data);
255 #endif
256
257 return Status::OK;
258 }
259
260 Status Command(ServerContext* context,
261 const CommandRequest* request,
262 CommandResponse* reply) override {
263 if (request->command() == PSCommand::CLEAR_PS) {
264 LOG(WARNING) << "[PS Command] Clear All";
265 cache_ps_->Clear();
266 } else if (request->command() == PSCommand::RELOAD_PS) {
267 LOG(WARNING) << "[PS Command] Reload PS";
268 CHECK_NE(request->arg1().size(), 0);
269 CHECK_NE(request->arg2().size(), 0);
270 CHECK_EQ(request->arg1().size(), 1);
271 LOG(WARNING) << "model_config_path = " << request->arg1()[0];
272 for (int i = 0; i < request->arg2().size(); i++) {
273 LOG(WARNING) << fmt::format("emb_file {}: {}", i, request->arg2()[i]);
274 }
275 std::vector<std::string> arg1;
276 for (auto& each : request->arg1()) {
277 arg1.push_back(each);
278 }
279 std::vector<std::string> arg2;
280 for (auto& each : request->arg2()) {
281 arg2.push_back(each);
282 }
283
284 cache_ps_->Initialize(arg1, arg2);
285 } else if (request->command() == PSCommand::LOAD_FAKE_DATA) {
286 if (request->arg1_size() != 1 ||
287 static_cast<size_t>(request->arg1(0).size()) != sizeof(int64_t)) {
288 LOG(ERROR) << "LOAD_FAKE_DATA: arg1 must be one " << sizeof(int64_t)
289 << "-byte int64_t (requested reply payload size)";
290 return Status(grpc::StatusCode::INVALID_ARGUMENT,
291 "LOAD_FAKE_DATA invalid arg1 size");
292 }
293 int64_t payload_bytes = 0;
294 std::memcpy(&payload_bytes, request->arg1(0).data(), sizeof(int64_t));
295 if (payload_bytes < 0) {
296 LOG(ERROR) << "LOAD_FAKE_DATA: payload_bytes must be non-negative, got "
297 << payload_bytes;
298 return Status(grpc::StatusCode::INVALID_ARGUMENT,
299 "payload_bytes must be non-negative");
300 }
301 constexpr int64_t kMaxReplyPayload = 16 * 1024 * 1024;
302 if (payload_bytes > kMaxReplyPayload) {
303 LOG(ERROR) << "LOAD_FAKE_DATA: payload_bytes " << payload_bytes
304 << " exceeds cap " << kMaxReplyPayload;
305 return Status(grpc::StatusCode::INVALID_ARGUMENT, "payload too large");
306 }
307 std::string fake(static_cast<size_t>(payload_bytes), '\xab');
308 reply->set_reply(std::move(fake));
309 } else if (request->command() == PSCommand::DUMP_FAKE_DATA) {
310 if (request->arg1_size() != 1 ||
311 static_cast<size_t>(request->arg1(0).size()) != sizeof(int64_t)) {
312 LOG(ERROR) << "DUMP_FAKE_DATA: arg1 must be one " << sizeof(int64_t)
313 << "-byte int64_t (payload bytes n)";
314 return Status(grpc::StatusCode::INVALID_ARGUMENT,
315 "DUMP_FAKE_DATA invalid arg1 size");
316 }
317 int64_t n = 0;
318 std::memcpy(&n, request->arg1(0).data(), sizeof(int64_t));
319 if (n <= 0) {
320 LOG(ERROR) << "DUMP_FAKE_DATA: n must be positive";
321 return Status(grpc::StatusCode::INVALID_ARGUMENT,
322 "DUMP_FAKE_DATA n must be positive");
323 }
324 if (n % static_cast<int64_t>(sizeof(float)) != 0) {
325 LOG(ERROR) << "DUMP_FAKE_DATA: n must be a multiple of "
326 << sizeof(float);
327 return Status(grpc::StatusCode::INVALID_ARGUMENT,
328 "DUMP_FAKE_DATA n must be multiple of sizeof(float)");
329 }
330 constexpr int64_t kMaxDumpBytes = 64 * 1024 * 1024;
331 if (n > kMaxDumpBytes) {
332 LOG(ERROR) << "DUMP_FAKE_DATA: n exceeds cap " << kMaxDumpBytes;
333 return Status(
334 grpc::StatusCode::INVALID_ARGUMENT, "DUMP_FAKE_DATA n exceeds cap");
335 }
336 // Receive fake data payload (used for write bandwidth benchmarking)
337 reply->set_reply("ok");
338 } else {
339 LOG(FATAL) << "invalid command";
340 }
341 return Status::OK;
342 }
343
344 Status PutParameter(ServerContext* context,
345 const PutParameterRequest* request,
346 PutParameterResponse* reply) override {
347 #ifdef ENABLE_PERF_REPORT
348 auto start_time = std::chrono::high_resolution_clock::now();
349 #endif
350 const ParameterCompressReader* reader =
351 reinterpret_cast<const ParameterCompressReader*>(
352 request->parameter_value().data());
353 int size = reader->item_size();
354 LOG(INFO) << "[PS] PutParameter: " << size << " keys";
355 uint64_t total_bytes = 0;
356
357 for (int i = 0; i < size; i++) {
358 cache_ps_->PutSingleParameter(reader->item(i), 0);
359 total_bytes += reader->item(i)->dim * sizeof(float);
360 }
361 LOG(INFO) << "[PS] PutParameter done: " << size << " keys";
362 total_put_requests_++;
363 total_put_keys_ += size;
364 total_put_bytes_ += total_bytes;
365
366 #ifdef ENABLE_PERF_REPORT
367 auto end_time = std::chrono::high_resolution_clock::now();
368 double start_us_for_key =
369 std::chrono::duration_cast<std::chrono::microseconds>(
370 start_time.time_since_epoch())
371 .count();
372 auto duration =
373 std::chrono::duration_cast<std::chrono::microseconds>(
374 end_time - start_time)
375 .count();
376 std::string op_latency_key =
377 "EmbWrite|" + std::to_string(static_cast<uint64_t>(start_us_for_key));
378 report("op_latency",
379 op_latency_key.c_str(),
380 "recserver_us",
381 static_cast<double>(duration));
382 #endif
383
384 return Status::OK;
385 }
386
387 Status UpdateParameter(ServerContext* context,
388 const UpdateParameterRequest* request,
389 UpdateParameterResponse* reply) override {
390 #ifdef ENABLE_PERF_REPORT
391 auto start_time = std::chrono::high_resolution_clock::now();
392 uint64_t trace_id = 0;
393 const auto trace_it =
394 context->client_metadata().find("x-recstore-trace-id");
395 if (trace_it != context->client_metadata().end()) {
396 std::string trace_id_str(
397 trace_it->second.data(), trace_it->second.length());
398 trace_id = static_cast<uint64_t>(
399 std::strtoull(trace_id_str.c_str(), nullptr, 10));
400 }
401 #endif
402 bool success = false;
403 int size = 0;
404 std::string table_name;
405 #ifdef ENABLE_PERF_REPORT
406 auto before_cache_update_time = std::chrono::high_resolution_clock::now();
407 #endif
408 try {
409 table_name = request->table_name();
410 const ParameterCompressReader* reader =
411 reinterpret_cast<const ParameterCompressReader*>(
412 request->gradients().data());
413 size = reader->item_size();
414
415 #ifdef ENABLE_PERF_REPORT
416 before_cache_update_time = std::chrono::high_resolution_clock::now();
417 #endif
418 success = cache_ps_->UpdateParameter(table_name, reader, 0);
419
420 RECSTORE_LOG_EVERY_MS(INFO, 2000)
421 << "UpdateParameter: table=" << table_name << ", keys=" << size;
422
423 reply->set_success(success);
424 } catch (const std::exception& e) {
425 LOG(ERROR) << "UpdateParameter error: " << e.what();
426 reply->set_success(false);
427 }
428
429 #ifdef ENABLE_PERF_REPORT
430 auto end_time = std::chrono::high_resolution_clock::now();
431 double start_us_for_key =
432 std::chrono::duration_cast<std::chrono::microseconds>(
433 start_time.time_since_epoch())
434 .count();
435 auto duration =
436 std::chrono::duration_cast<std::chrono::microseconds>(
437 end_time - start_time)
438 .count();
439 std::string op_latency_key =
440 "EmbUpdate|" + std::to_string(static_cast<uint64_t>(start_us_for_key));
441 report("op_latency",
442 op_latency_key.c_str(),
443 "recserver_us",
444 static_cast<double>(duration));
445
446 auto backend_update_duration =
447 std::chrono::duration_cast<std::chrono::microseconds>(
448 end_time - before_cache_update_time)
449 .count();
450 const uint64_t effective_trace_id =
451 trace_id == 0 ? static_cast<uint64_t>(start_us_for_key) : trace_id;
452 std::string update_stage_id =
453 "grpc_server::EmbUpdate|" + std::to_string(effective_trace_id);
454 report("embupdate_stages",
455 update_stage_id.c_str(),
456 "server_total_us",
457 static_cast<double>(duration));
458 report("embupdate_stages",
459 update_stage_id.c_str(),
460 "server_backend_update_us",
461 static_cast<double>(backend_update_duration));
462 report("embupdate_stages",
463 update_stage_id.c_str(),
464 "server_request_size",
465 static_cast<double>(size));
466 report("embupdate_stages",
467 update_stage_id.c_str(),
468 "server_success",
469 success ? 1.0 : 0.0);
470 #endif
471
472 return Status::OK;
473 }
474
475 Status InitEmbeddingTable(ServerContext* context,
476 const InitEmbeddingTableRequest* request,
477 InitEmbeddingTableResponse* reply) override {
478 #ifdef ENABLE_PERF_REPORT
479 auto start_time = std::chrono::high_resolution_clock::now();
480 #endif
481 try {
482 if (request->has_config_payload()) {
483 auto payload = request->config_payload();
484 nlohmann::json cfg = nlohmann::json::parse(payload);
485 uint64_t num_embeddings = cfg.value("num_embeddings", 0);
486 uint64_t embedding_dim = cfg.value("embedding_dim", 0);
487 RECSTORE_LOG_EVERY_MS(INFO, 2000)
488 << "InitEmbeddingTable: table=" << request->table_name()
489 << ", num_embeddings=" << num_embeddings
490 << ", embedding_dim=" << embedding_dim;
491
492 bool init_success = cache_ps_->InitTable(
493 request->table_name(), num_embeddings, embedding_dim);
494 reply->set_success(init_success);
495 } else {
496 LOG(WARNING) << "InitEmbeddingTable called without config_payload";
497 reply->set_success(false);
498 }
499 } catch (const std::exception& e) {
500 LOG(ERROR) << "InitEmbeddingTable error: " << e.what();
501 reply->set_success(false);
502 }
503
504 #ifdef ENABLE_PERF_REPORT
505 auto end_time = std::chrono::high_resolution_clock::now();
506 double start_us_for_key =
507 std::chrono::duration_cast<std::chrono::microseconds>(
508 start_time.time_since_epoch())
509 .count();
510 auto duration =
511 std::chrono::duration_cast<std::chrono::microseconds>(
512 end_time - start_time)
513 .count();
514 std::string op_latency_key =
515 "InitEmbeddingTable|" +
516 std::to_string(static_cast<uint64_t>(start_us_for_key));
517 report("op_latency",
518 op_latency_key.c_str(),
519 "recserver_us",
520 static_cast<double>(duration));
521 #endif
522
523 return Status::OK;
524 }
525
526 private:
527 CachePS* cache_ps_;
528 std::atomic<uint64_t> total_get_requests_{0};
529 std::atomic<uint64_t> total_put_requests_{0};
530 std::atomic<uint64_t> total_get_keys_{0};
531 std::atomic<uint64_t> total_put_keys_{0};
532 std::atomic<uint64_t> total_get_bytes_{0};
533 std::atomic<uint64_t> total_put_bytes_{0};
534 std::chrono::steady_clock::time_point start_time_;
535 };
536
537 namespace recstore {
538 class GRPCParameterServer : public BaseParameterServer {
539 public:
540 GRPCParameterServer() = default;
541
542 void Run() {
543 // Check whether multi-shard mode is configured
544 int num_shards = 1; // default: single shard
545 if (config_["cache_ps"].contains("num_shards")) {
546 num_shards = config_["cache_ps"]["num_shards"];
547 }
548 const std::optional<int> local_shard_id =
549 FLAGS_grpc_local_shard_id >= 0
550 ? std::make_optional(FLAGS_grpc_local_shard_id)
551 : std::nullopt;
552
553 if (num_shards > 1) {
554 // Multi-server startup
555 std::cout << "Starting distributed parameter server (gRPC), number "
556 "of shards: "
557 << num_shards << std::endl;
558
559 if (!config_["cache_ps"].contains("servers")) {
560 LOG(FATAL) << "num_shards > 1 but cache_ps.servers is missing";
561 return;
562 }
563
564 const auto& cache_ps_config = config_["cache_ps"];
565 auto servers = SelectGRPCShardConfigs(cache_ps_config, local_shard_id);
566 const auto configured_servers = cache_ps_config["servers"];
567 if (configured_servers.size() != num_shards) {
568 LOG(FATAL) << "servers count (" << configured_servers.size()
569 << ") does not match num_shards (" << num_shards << ")";
570 return;
571 }
572 if (local_shard_id.has_value() && servers.empty()) {
573 LOG(FATAL) << "grpc_local_shard_id=" << *local_shard_id
574 << " is not present in cache_ps.servers";
575 return;
576 }
577 if (!local_shard_id.has_value() &&
578 servers.size() != configured_servers.size()) {
579 LOG(FATAL) << "Selected shard count (" << servers.size()
580 << ") does not match configured server count ("
581 << configured_servers.size() << ")";
582 return;
583 }
584
585 std::vector<std::thread> server_threads;
586
587 for (auto& server_config : servers) {
588 server_threads.emplace_back([this, server_config]() {
589 try {
590 std::string host = server_config["host"];
591 int port = server_config["port"];
592 int shard = server_config["shard"];
593
594 std::string server_address = host + ":" + std::to_string(port);
595
596 nlohmann::json shard_config = config_["cache_ps"];
597 if (shard_config.contains("base_kv_config") &&
598 shard_config["base_kv_config"].is_object()) {
599 auto& base_kv_config = shard_config["base_kv_config"];
600 AppendShardSuffixIfPresent(base_kv_config, "path", shard);
601 AppendShardSuffixIfPresent(base_kv_config, "rocksdb_path", shard);
602 AppendShardSuffixToNestedFilePaths(base_kv_config, shard);
603 LOG(INFO) << "gRPC shard " << shard
604 << " using base_kv_config: " << base_kv_config.dump();
605 }
606
607 auto cache_ps = std::make_unique<CachePS>(shard_config);
608 ParameterServiceImpl service(cache_ps.get());
609
610 grpc::EnableDefaultHealthCheckService(true);
611 grpc::reflection::InitProtoReflectionServerBuilderPlugin();
612 ServerBuilder builder;
613 builder.AddListeningPort(
614 server_address, grpc::InsecureServerCredentials());
615 builder.RegisterService(&service);
616 builder.SetMaxReceiveMessageSize(-1); // Unlimited
617 builder.SetMaxSendMessageSize(-1); // Unlimited
618 std::unique_ptr<Server> server(builder.BuildAndStart());
619
620 if (!server) {
621 std::string err_msg = fmt::format(
622 "FATAL: Failed to start gRPC server shard {} "
623 "on {}. "
624 "Port might be in use or invalid "
625 "configuration. "
626 "Check if port {} is already occupied.",
627 shard,
628 server_address,
629 port);
630 std::cerr << err_msg << std::endl;
631 LOG(FATAL) << err_msg;
632 return;
633 }
634 std::cout << "Server shard " << shard << " listening on "
635 << server_address << std::endl;
636 server->Wait();
637 } catch (const std::exception& e) {
638 std::cerr << "FATAL: Uncaught exception in shard thread: "
639 << e.what() << std::endl;
640 LOG(FATAL) << "Uncaught exception in shard thread: " << e.what();
641 } catch (...) {
642 std::cerr << "FATAL: Unknown exception in shard thread"
643 << std::endl;
644 LOG(FATAL) << "Unknown exception in shard thread";
645 }
646 });
647 }
648
649 // Wait for all server threads
650 for (auto& t : server_threads) {
651 t.join();
652 }
653 } else {
654 // Single-server startup
655 std::cout << "Starting single parameter server" << std::endl;
656 std::string server_address("0.0.0.0:15000");
657 auto cache_ps = std::make_unique<CachePS>(config_["cache_ps"]);
658 ParameterServiceImpl service(cache_ps.get());
659
660 std::atomic<bool> metrics_running{true};
661 std::thread metrics_thread([&service, &metrics_running]() {
662 while (metrics_running) {
663 std::this_thread::sleep_for(std::chrono::seconds(10));
664 service.PrintMetrics();
665 service.ResetMetrics();
666 }
667 });
668
669 grpc::EnableDefaultHealthCheckService(true);
670 grpc::reflection::InitProtoReflectionServerBuilderPlugin();
671 ServerBuilder builder;
672 builder.AddListeningPort(
673 server_address, grpc::InsecureServerCredentials());
674 builder.RegisterService(&service);
675 builder.SetMaxReceiveMessageSize(-1); // Unlimited
676 builder.SetMaxSendMessageSize(-1); // Unlimited
677 std::unique_ptr<Server> server(builder.BuildAndStart());
678 std::cerr << "sever built succesfully" << std::endl;
679 if (!server) {
680 std::string err_msg = fmt::format(
681 "FATAL: Failed to start gRPC server on {}. "
682 "Port might be in use or invalid configuration.",
683 server_address);
684 std::cerr << err_msg << std::endl;
685 LOG(FATAL) << err_msg;
686 metrics_running = false;
687 if (metrics_thread.joinable()) {
688 metrics_thread.join();
689 }
690 return;
691 }
692 std::cout << "Server listening on " << server_address << std::endl;
693 server->Wait();
694
695 metrics_running = false;
696 if (metrics_thread.joinable()) {
697 metrics_thread.join();
698 }
699 }
700 }
701 };
702
703 FACTORY_REGISTER(BaseParameterServer, GRPCParameterServer, GRPCParameterServer);
704
705 } // namespace recstore
706
707 #ifndef RECSTORE_NO_SERVER_MAIN
708 int main(int argc, char** argv) {
709 base::Init(&argc, &argv);
710 xmh::Reporter::StartReportThread(2000);
711 const std::string config_path =
712 FLAGS_config_path.empty()
713 ? base::ResolveRecStoreConfigPath().string()
714 : FLAGS_config_path;
715 std::ifstream config_file(config_path);
716 if (!config_file.is_open()) {
717 throw std::runtime_error("Cannot open config file: " + config_path);
718 }
719 nlohmann::json ex;
720 config_file >> ex;
721 recstore::GRPCParameterServer ps;
722 std::cout << "Parameter server config: " << ex.dump(2) << std::endl;
723 ps.Init(ex);
724 ps.Run();
725 return 0;
726 }
727 #endif
728