ps/grpc/grpc_ps_server.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include <grpcpp/ext/proto_server_reflection_plugin.h> | ||
| 2 | #include <grpcpp/grpcpp.h> | ||
| 3 | #include <grpcpp/health_check_service_interface.h> | ||
| 4 | |||
| 5 | #include <cstdint> | ||
| 6 | #include <cstring> | ||
| 7 | #include <fstream> | ||
| 8 | #include <future> | ||
| 9 | #include <optional> | ||
| 10 | #include <stdexcept> | ||
| 11 | #include <string> | ||
| 12 | #include <thread> | ||
| 13 | #include <vector> | ||
| 14 | |||
| 15 | #include "base/array.h" | ||
| 16 | #include "base/base.h" | ||
| 17 | #include "base/flatc.h" | ||
| 18 | #include "base/init.h" | ||
| 19 | #include "base/timer.h" | ||
| 20 | #include "ps.grpc.pb.h" | ||
| 21 | #include "ps.pb.h" | ||
| 22 | #include "ps/base/base_ps_server.h" | ||
| 23 | #include "ps/base/cache_ps_impl.h" | ||
| 24 | #include "ps/base/parameters.h" | ||
| 25 | #include "recstore_config.h" | ||
| 26 | #include "src/base/config.h" | ||
| 27 | |||
| 28 | #ifdef ENABLE_PERF_REPORT | ||
| 29 | # include <chrono> | ||
| 30 | # include <cstdlib> | ||
| 31 | |||
| 32 | # include "base/report/report_client.h" | ||
| 33 | #else | ||
| 34 | # include "../report_client.h" | ||
| 35 | #endif | ||
| 36 | |||
| 37 | using grpc::Server; | ||
| 38 | using grpc::ServerBuilder; | ||
| 39 | using grpc::ServerContext; | ||
| 40 | using grpc::Status; | ||
| 41 | |||
| 42 | using recstoreps::CommandRequest; | ||
| 43 | using recstoreps::CommandResponse; | ||
| 44 | using recstoreps::GetParameterRequest; | ||
| 45 | using recstoreps::GetParameterResponse; | ||
| 46 | using recstoreps::InitEmbeddingTableRequest; | ||
| 47 | using recstoreps::InitEmbeddingTableResponse; | ||
| 48 | using recstoreps::PSCommand; | ||
| 49 | using recstoreps::PutParameterRequest; | ||
| 50 | using recstoreps::PutParameterResponse; | ||
| 51 | using recstoreps::UpdateParameterRequest; | ||
| 52 | using recstoreps::UpdateParameterResponse; | ||
| 53 | |||
| 54 | DEFINE_string(config_path, "", "config file path"); | ||
| 55 | DEFINE_int32(grpc_local_shard_id, | ||
| 56 | -1, | ||
| 57 | "Only start the specified shard in multi-shard gRPC mode; " | ||
| 58 | "-1 means start all configured shards"); | ||
| 59 | |||
| 60 | namespace { | ||
| 61 | |||
| 62 | ✗ | void AppendShardSuffixIfPresent( | |
| 63 | nlohmann::json& config_node, const char* key, int shard_id) { | ||
| 64 | ✗ | if (!config_node.contains(key) || !config_node[key].is_string()) { | |
| 65 | ✗ | return; | |
| 66 | } | ||
| 67 | ✗ | config_node[key] = | |
| 68 | ✗ | config_node[key].get<std::string>() + "_" + std::to_string(shard_id); | |
| 69 | } | ||
| 70 | |||
| 71 | ✗ | void AppendShardSuffixToNestedFilePaths(nlohmann::json& node, int shard_id) { | |
| 72 | ✗ | if (node.is_object()) { | |
| 73 | ✗ | for (auto& item : node.items()) { | |
| 74 | ✗ | if (item.key() == "file_path" && item.value().is_string()) { | |
| 75 | ✗ | item.value() = | |
| 76 | ✗ | item.value().get<std::string>() + "_" + std::to_string(shard_id); | |
| 77 | ✗ | continue; | |
| 78 | } | ||
| 79 | ✗ | AppendShardSuffixToNestedFilePaths(item.value(), shard_id); | |
| 80 | ✗ | } | |
| 81 | ✗ | return; | |
| 82 | } | ||
| 83 | ✗ | if (node.is_array()) { | |
| 84 | ✗ | for (auto& item : node) { | |
| 85 | ✗ | AppendShardSuffixToNestedFilePaths(item, shard_id); | |
| 86 | } | ||
| 87 | } | ||
| 88 | } | ||
| 89 | |||
| 90 | std::vector<nlohmann::json> | ||
| 91 | ✗ | SelectGRPCShardConfigs(const nlohmann::json& cache_ps_config, | |
| 92 | const std::optional<int>& local_shard_id) { | ||
| 93 | ✗ | std::vector<nlohmann::json> selected; | |
| 94 | ✗ | if (!cache_ps_config.contains("servers") || | |
| 95 | ✗ | !cache_ps_config["servers"].is_array()) { | |
| 96 | ✗ | return selected; | |
| 97 | } | ||
| 98 | |||
| 99 | ✗ | for (const auto& server_config : cache_ps_config["servers"]) { | |
| 100 | ✗ | if (!local_shard_id.has_value()) { | |
| 101 | ✗ | selected.push_back(server_config); | |
| 102 | ✗ | continue; | |
| 103 | } | ||
| 104 | ✗ | if (!server_config.contains("shard") || | |
| 105 | ✗ | !server_config["shard"].is_number_integer()) { | |
| 106 | ✗ | continue; | |
| 107 | } | ||
| 108 | ✗ | if (server_config["shard"].get<int>() == *local_shard_id) { | |
| 109 | ✗ | selected.push_back(server_config); | |
| 110 | } | ||
| 111 | } | ||
| 112 | ✗ | return selected; | |
| 113 | ✗ | } | |
| 114 | |||
| 115 | } // namespace | ||
| 116 | |||
| 117 | class ParameterServiceImpl final | ||
| 118 | : public recstoreps::ParameterService::Service { | ||
| 119 | public: | ||
| 120 | ✗ | ParameterServiceImpl(CachePS* cache_ps) { | |
| 121 | ✗ | cache_ps_ = cache_ps; | |
| 122 | ✗ | start_time_ = std::chrono::steady_clock::now(); | |
| 123 | ✗ | } | |
| 124 | ✗ | void ResetMetrics() { | |
| 125 | ✗ | total_get_requests_ = 0; | |
| 126 | ✗ | total_put_requests_ = 0; | |
| 127 | ✗ | total_get_keys_ = 0; | |
| 128 | ✗ | total_put_keys_ = 0; | |
| 129 | ✗ | total_get_bytes_ = 0; | |
| 130 | ✗ | total_put_bytes_ = 0; | |
| 131 | ✗ | start_time_ = std::chrono::steady_clock::now(); | |
| 132 | ✗ | } | |
| 133 | ✗ | void PrintMetrics(const std::string& table_name = "grpc_ps_server_metrics", | |
| 134 | const std::string& unique_id = "default_server") { | ||
| 135 | ✗ | auto now = std::chrono::steady_clock::now(); | |
| 136 | ✗ | double elapsed_s = std::chrono::duration<double>(now - start_time_).count(); | |
| 137 | ✗ | if (elapsed_s > 0) { | |
| 138 | double overall_qps = | ||
| 139 | ✗ | (total_get_requests_ + total_put_requests_) / elapsed_s; | |
| 140 | double overall_throughput_mbps = | ||
| 141 | ✗ | ((total_get_bytes_ + total_put_bytes_) / 1024.0 / 1024.0) / elapsed_s; | |
| 142 | |||
| 143 | // Report QPS and throughput metrics | ||
| 144 | |||
| 145 | // report(table_name.c_str(), unique_id.c_str(), "overall_qps", | ||
| 146 | // overall_qps); report(table_name.c_str(), | ||
| 147 | // unique_id.c_str(), | ||
| 148 | // "overall_throughput_mbps", | ||
| 149 | // overall_throughput_mbps); | ||
| 150 | } | ||
| 151 | ✗ | } | |
| 152 | |||
| 153 | private: | ||
| 154 | ✗ | Status GetParameter(ServerContext* context, | |
| 155 | const GetParameterRequest* request, | ||
| 156 | GetParameterResponse* reply) override { | ||
| 157 | #ifdef ENABLE_PERF_REPORT | ||
| 158 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 159 | #endif | ||
| 160 | ✗ | base::ConstArray<uint64_t> keys_array(request->keys()); | |
| 161 | ✗ | bool isPerf = request->has_perf() && request->perf(); | |
| 162 | ✗ | if (isPerf) { | |
| 163 | ✗ | xmh::PerfCounter::Record("PS Get Keys", keys_array.Size()); | |
| 164 | } | ||
| 165 | ✗ | xmh::Timer timer_ps_get_req("PS GetParameter Req"); | |
| 166 | ✗ | ParameterCompressor compressor(std::numeric_limits<int>::max()); | |
| 167 | ✗ | std::vector<std::string> blocks; | |
| 168 | ✗ | RECSTORE_LOG_EVERY_MS(INFO, 1000) | |
| 169 | ✗ | << "[PS] Getting " << keys_array.Size() << " keys"; | |
| 170 | ✗ | int total_dim = 0; | |
| 171 | #ifdef ENABLE_PERF_REPORT | ||
| 172 | auto cache_start_time = std::chrono::high_resolution_clock::now(); | ||
| 173 | #endif | ||
| 174 | ✗ | std::vector<ParameterPack> packs; | |
| 175 | ✗ | packs.reserve(keys_array.Size()); | |
| 176 | ✗ | cache_ps_->GetParameterRun2Completion(keys_array, packs, 0); | |
| 177 | |||
| 178 | ✗ | for (auto& pack : packs) { | |
| 179 | ✗ | compressor.AddItem(pack, &blocks); | |
| 180 | ✗ | total_dim += pack.dim; | |
| 181 | } | ||
| 182 | #ifdef ENABLE_PERF_REPORT | ||
| 183 | auto cache_end_time = std::chrono::high_resolution_clock::now(); | ||
| 184 | auto cache_duration = | ||
| 185 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 186 | cache_end_time - cache_start_time) | ||
| 187 | .count(); | ||
| 188 | double start_us_for_cache = | ||
| 189 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 190 | start_time.time_since_epoch()) | ||
| 191 | .count(); | ||
| 192 | std::string report_id_for_cache = | ||
| 193 | "grpc_server::GetParameter|" + | ||
| 194 | std::to_string(static_cast<uint64_t>(start_us_for_cache)); | ||
| 195 | report("embread_stages", | ||
| 196 | report_id_for_cache.c_str(), | ||
| 197 | "cache_lookup_us", | ||
| 198 | static_cast<double>(cache_duration)); | ||
| 199 | #endif | ||
| 200 | |||
| 201 | ✗ | compressor.ToBlock(&blocks); | |
| 202 | ✗ | CHECK_EQ(blocks.size(), 1); | |
| 203 | ✗ | reply->mutable_parameter_value()->swap(blocks[0]); | |
| 204 | ✗ | total_get_requests_++; | |
| 205 | ✗ | total_get_keys_ += keys_array.Size(); | |
| 206 | ✗ | total_get_bytes_ += total_dim * sizeof(float); | |
| 207 | |||
| 208 | ✗ | if (isPerf) { | |
| 209 | ✗ | timer_ps_get_req.end(); | |
| 210 | } else { | ||
| 211 | ✗ | timer_ps_get_req.destroy(); | |
| 212 | } | ||
| 213 | |||
| 214 | #ifdef ENABLE_PERF_REPORT | ||
| 215 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 216 | double start_us = | ||
| 217 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 218 | start_time.time_since_epoch()) | ||
| 219 | .count(); | ||
| 220 | auto duration = | ||
| 221 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 222 | end_time - start_time) | ||
| 223 | .count(); | ||
| 224 | |||
| 225 | std::string report_id = "grpc_server::GetParameter|" + | ||
| 226 | std::to_string(static_cast<uint64_t>(start_us)); | ||
| 227 | |||
| 228 | std::string op_latency_key = | ||
| 229 | "EmbRead|" + std::to_string(static_cast<uint64_t>(start_us)); | ||
| 230 | report("op_latency", | ||
| 231 | op_latency_key.c_str(), | ||
| 232 | "recserver_us", | ||
| 233 | static_cast<double>(duration)); | ||
| 234 | |||
| 235 | report("embread_stages", | ||
| 236 | report_id.c_str(), | ||
| 237 | "duration_us", | ||
| 238 | static_cast<double>(duration)); | ||
| 239 | |||
| 240 | report("embread_stages", | ||
| 241 | report_id.c_str(), | ||
| 242 | "request_size", | ||
| 243 | static_cast<double>(keys_array.Size())); | ||
| 244 | |||
| 245 | std::string unique_id = | ||
| 246 | "embread_debug|" + std::to_string(static_cast<uint64_t>(start_us)); | ||
| 247 | FlameGraphData grpc_server_data = { | ||
| 248 | "grpc_ps_server::GetParameter", | ||
| 249 | start_us, | ||
| 250 | 2, // level | ||
| 251 | static_cast<double>(duration), | ||
| 252 | static_cast<double>(duration)}; | ||
| 253 | report_flame_graph( | ||
| 254 | "emb_read_flame_map", unique_id.c_str(), grpc_server_data); | ||
| 255 | #endif | ||
| 256 | |||
| 257 | ✗ | return Status::OK; | |
| 258 | ✗ | } | |
| 259 | |||
| 260 | ✗ | Status Command(ServerContext* context, | |
| 261 | const CommandRequest* request, | ||
| 262 | CommandResponse* reply) override { | ||
| 263 | ✗ | if (request->command() == PSCommand::CLEAR_PS) { | |
| 264 | ✗ | LOG(WARNING) << "[PS Command] Clear All"; | |
| 265 | ✗ | cache_ps_->Clear(); | |
| 266 | ✗ | } else if (request->command() == PSCommand::RELOAD_PS) { | |
| 267 | ✗ | LOG(WARNING) << "[PS Command] Reload PS"; | |
| 268 | ✗ | CHECK_NE(request->arg1().size(), 0); | |
| 269 | ✗ | CHECK_NE(request->arg2().size(), 0); | |
| 270 | ✗ | CHECK_EQ(request->arg1().size(), 1); | |
| 271 | ✗ | LOG(WARNING) << "model_config_path = " << request->arg1()[0]; | |
| 272 | ✗ | for (int i = 0; i < request->arg2().size(); i++) { | |
| 273 | ✗ | LOG(WARNING) << fmt::format("emb_file {}: {}", i, request->arg2()[i]); | |
| 274 | } | ||
| 275 | ✗ | std::vector<std::string> arg1; | |
| 276 | ✗ | for (auto& each : request->arg1()) { | |
| 277 | ✗ | arg1.push_back(each); | |
| 278 | } | ||
| 279 | ✗ | std::vector<std::string> arg2; | |
| 280 | ✗ | for (auto& each : request->arg2()) { | |
| 281 | ✗ | arg2.push_back(each); | |
| 282 | } | ||
| 283 | |||
| 284 | ✗ | cache_ps_->Initialize(arg1, arg2); | |
| 285 | ✗ | } else if (request->command() == PSCommand::LOAD_FAKE_DATA) { | |
| 286 | ✗ | if (request->arg1_size() != 1 || | |
| 287 | ✗ | static_cast<size_t>(request->arg1(0).size()) != sizeof(int64_t)) { | |
| 288 | ✗ | LOG(ERROR) << "LOAD_FAKE_DATA: arg1 must be one " << sizeof(int64_t) | |
| 289 | ✗ | << "-byte int64_t (requested reply payload size)"; | |
| 290 | return Status(grpc::StatusCode::INVALID_ARGUMENT, | ||
| 291 | ✗ | "LOAD_FAKE_DATA invalid arg1 size"); | |
| 292 | } | ||
| 293 | ✗ | int64_t payload_bytes = 0; | |
| 294 | ✗ | std::memcpy(&payload_bytes, request->arg1(0).data(), sizeof(int64_t)); | |
| 295 | ✗ | if (payload_bytes < 0) { | |
| 296 | ✗ | LOG(ERROR) << "LOAD_FAKE_DATA: payload_bytes must be non-negative, got " | |
| 297 | ✗ | << payload_bytes; | |
| 298 | return Status(grpc::StatusCode::INVALID_ARGUMENT, | ||
| 299 | ✗ | "payload_bytes must be non-negative"); | |
| 300 | } | ||
| 301 | ✗ | constexpr int64_t kMaxReplyPayload = 16 * 1024 * 1024; | |
| 302 | ✗ | if (payload_bytes > kMaxReplyPayload) { | |
| 303 | ✗ | LOG(ERROR) << "LOAD_FAKE_DATA: payload_bytes " << payload_bytes | |
| 304 | ✗ | << " exceeds cap " << kMaxReplyPayload; | |
| 305 | ✗ | return Status(grpc::StatusCode::INVALID_ARGUMENT, "payload too large"); | |
| 306 | } | ||
| 307 | ✗ | std::string fake(static_cast<size_t>(payload_bytes), '\xab'); | |
| 308 | ✗ | reply->set_reply(std::move(fake)); | |
| 309 | ✗ | } else if (request->command() == PSCommand::DUMP_FAKE_DATA) { | |
| 310 | ✗ | if (request->arg1_size() != 1 || | |
| 311 | ✗ | static_cast<size_t>(request->arg1(0).size()) != sizeof(int64_t)) { | |
| 312 | ✗ | LOG(ERROR) << "DUMP_FAKE_DATA: arg1 must be one " << sizeof(int64_t) | |
| 313 | ✗ | << "-byte int64_t (payload bytes n)"; | |
| 314 | return Status(grpc::StatusCode::INVALID_ARGUMENT, | ||
| 315 | ✗ | "DUMP_FAKE_DATA invalid arg1 size"); | |
| 316 | } | ||
| 317 | ✗ | int64_t n = 0; | |
| 318 | ✗ | std::memcpy(&n, request->arg1(0).data(), sizeof(int64_t)); | |
| 319 | ✗ | if (n <= 0) { | |
| 320 | ✗ | LOG(ERROR) << "DUMP_FAKE_DATA: n must be positive"; | |
| 321 | return Status(grpc::StatusCode::INVALID_ARGUMENT, | ||
| 322 | ✗ | "DUMP_FAKE_DATA n must be positive"); | |
| 323 | } | ||
| 324 | ✗ | if (n % static_cast<int64_t>(sizeof(float)) != 0) { | |
| 325 | ✗ | LOG(ERROR) << "DUMP_FAKE_DATA: n must be a multiple of " | |
| 326 | ✗ | << sizeof(float); | |
| 327 | return Status(grpc::StatusCode::INVALID_ARGUMENT, | ||
| 328 | ✗ | "DUMP_FAKE_DATA n must be multiple of sizeof(float)"); | |
| 329 | } | ||
| 330 | ✗ | constexpr int64_t kMaxDumpBytes = 64 * 1024 * 1024; | |
| 331 | ✗ | if (n > kMaxDumpBytes) { | |
| 332 | ✗ | LOG(ERROR) << "DUMP_FAKE_DATA: n exceeds cap " << kMaxDumpBytes; | |
| 333 | return Status( | ||
| 334 | ✗ | grpc::StatusCode::INVALID_ARGUMENT, "DUMP_FAKE_DATA n exceeds cap"); | |
| 335 | } | ||
| 336 | // Receive fake data payload (used for write bandwidth benchmarking) | ||
| 337 | reply->set_reply("ok"); | ||
| 338 | } else { | ||
| 339 | ✗ | LOG(FATAL) << "invalid command"; | |
| 340 | } | ||
| 341 | ✗ | return Status::OK; | |
| 342 | } | ||
| 343 | |||
| 344 | ✗ | Status PutParameter(ServerContext* context, | |
| 345 | const PutParameterRequest* request, | ||
| 346 | PutParameterResponse* reply) override { | ||
| 347 | #ifdef ENABLE_PERF_REPORT | ||
| 348 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 349 | #endif | ||
| 350 | const ParameterCompressReader* reader = | ||
| 351 | reinterpret_cast<const ParameterCompressReader*>( | ||
| 352 | ✗ | request->parameter_value().data()); | |
| 353 | ✗ | int size = reader->item_size(); | |
| 354 | ✗ | LOG(INFO) << "[PS] PutParameter: " << size << " keys"; | |
| 355 | ✗ | uint64_t total_bytes = 0; | |
| 356 | |||
| 357 | ✗ | for (int i = 0; i < size; i++) { | |
| 358 | ✗ | cache_ps_->PutSingleParameter(reader->item(i), 0); | |
| 359 | ✗ | total_bytes += reader->item(i)->dim * sizeof(float); | |
| 360 | } | ||
| 361 | ✗ | LOG(INFO) << "[PS] PutParameter done: " << size << " keys"; | |
| 362 | ✗ | total_put_requests_++; | |
| 363 | ✗ | total_put_keys_ += size; | |
| 364 | ✗ | total_put_bytes_ += total_bytes; | |
| 365 | |||
| 366 | #ifdef ENABLE_PERF_REPORT | ||
| 367 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 368 | double start_us_for_key = | ||
| 369 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 370 | start_time.time_since_epoch()) | ||
| 371 | .count(); | ||
| 372 | auto duration = | ||
| 373 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 374 | end_time - start_time) | ||
| 375 | .count(); | ||
| 376 | std::string op_latency_key = | ||
| 377 | "EmbWrite|" + std::to_string(static_cast<uint64_t>(start_us_for_key)); | ||
| 378 | report("op_latency", | ||
| 379 | op_latency_key.c_str(), | ||
| 380 | "recserver_us", | ||
| 381 | static_cast<double>(duration)); | ||
| 382 | #endif | ||
| 383 | |||
| 384 | ✗ | return Status::OK; | |
| 385 | } | ||
| 386 | |||
| 387 | ✗ | Status UpdateParameter(ServerContext* context, | |
| 388 | const UpdateParameterRequest* request, | ||
| 389 | UpdateParameterResponse* reply) override { | ||
| 390 | #ifdef ENABLE_PERF_REPORT | ||
| 391 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 392 | uint64_t trace_id = 0; | ||
| 393 | const auto trace_it = | ||
| 394 | context->client_metadata().find("x-recstore-trace-id"); | ||
| 395 | if (trace_it != context->client_metadata().end()) { | ||
| 396 | std::string trace_id_str( | ||
| 397 | trace_it->second.data(), trace_it->second.length()); | ||
| 398 | trace_id = static_cast<uint64_t>( | ||
| 399 | std::strtoull(trace_id_str.c_str(), nullptr, 10)); | ||
| 400 | } | ||
| 401 | #endif | ||
| 402 | ✗ | bool success = false; | |
| 403 | ✗ | int size = 0; | |
| 404 | ✗ | std::string table_name; | |
| 405 | #ifdef ENABLE_PERF_REPORT | ||
| 406 | auto before_cache_update_time = std::chrono::high_resolution_clock::now(); | ||
| 407 | #endif | ||
| 408 | try { | ||
| 409 | ✗ | table_name = request->table_name(); | |
| 410 | const ParameterCompressReader* reader = | ||
| 411 | reinterpret_cast<const ParameterCompressReader*>( | ||
| 412 | ✗ | request->gradients().data()); | |
| 413 | ✗ | size = reader->item_size(); | |
| 414 | |||
| 415 | #ifdef ENABLE_PERF_REPORT | ||
| 416 | before_cache_update_time = std::chrono::high_resolution_clock::now(); | ||
| 417 | #endif | ||
| 418 | ✗ | success = cache_ps_->UpdateParameter(table_name, reader, 0); | |
| 419 | |||
| 420 | ✗ | RECSTORE_LOG_EVERY_MS(INFO, 2000) | |
| 421 | ✗ | << "UpdateParameter: table=" << table_name << ", keys=" << size; | |
| 422 | |||
| 423 | ✗ | reply->set_success(success); | |
| 424 | ✗ | } catch (const std::exception& e) { | |
| 425 | ✗ | LOG(ERROR) << "UpdateParameter error: " << e.what(); | |
| 426 | ✗ | reply->set_success(false); | |
| 427 | ✗ | } | |
| 428 | |||
| 429 | #ifdef ENABLE_PERF_REPORT | ||
| 430 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 431 | double start_us_for_key = | ||
| 432 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 433 | start_time.time_since_epoch()) | ||
| 434 | .count(); | ||
| 435 | auto duration = | ||
| 436 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 437 | end_time - start_time) | ||
| 438 | .count(); | ||
| 439 | std::string op_latency_key = | ||
| 440 | "EmbUpdate|" + std::to_string(static_cast<uint64_t>(start_us_for_key)); | ||
| 441 | report("op_latency", | ||
| 442 | op_latency_key.c_str(), | ||
| 443 | "recserver_us", | ||
| 444 | static_cast<double>(duration)); | ||
| 445 | |||
| 446 | auto backend_update_duration = | ||
| 447 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 448 | end_time - before_cache_update_time) | ||
| 449 | .count(); | ||
| 450 | const uint64_t effective_trace_id = | ||
| 451 | trace_id == 0 ? static_cast<uint64_t>(start_us_for_key) : trace_id; | ||
| 452 | std::string update_stage_id = | ||
| 453 | "grpc_server::EmbUpdate|" + std::to_string(effective_trace_id); | ||
| 454 | report("embupdate_stages", | ||
| 455 | update_stage_id.c_str(), | ||
| 456 | "server_total_us", | ||
| 457 | static_cast<double>(duration)); | ||
| 458 | report("embupdate_stages", | ||
| 459 | update_stage_id.c_str(), | ||
| 460 | "server_backend_update_us", | ||
| 461 | static_cast<double>(backend_update_duration)); | ||
| 462 | report("embupdate_stages", | ||
| 463 | update_stage_id.c_str(), | ||
| 464 | "server_request_size", | ||
| 465 | static_cast<double>(size)); | ||
| 466 | report("embupdate_stages", | ||
| 467 | update_stage_id.c_str(), | ||
| 468 | "server_success", | ||
| 469 | success ? 1.0 : 0.0); | ||
| 470 | #endif | ||
| 471 | |||
| 472 | ✗ | return Status::OK; | |
| 473 | ✗ | } | |
| 474 | |||
| 475 | ✗ | Status InitEmbeddingTable(ServerContext* context, | |
| 476 | const InitEmbeddingTableRequest* request, | ||
| 477 | InitEmbeddingTableResponse* reply) override { | ||
| 478 | #ifdef ENABLE_PERF_REPORT | ||
| 479 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 480 | #endif | ||
| 481 | try { | ||
| 482 | ✗ | if (request->has_config_payload()) { | |
| 483 | ✗ | auto payload = request->config_payload(); | |
| 484 | ✗ | nlohmann::json cfg = nlohmann::json::parse(payload); | |
| 485 | ✗ | uint64_t num_embeddings = cfg.value("num_embeddings", 0); | |
| 486 | ✗ | uint64_t embedding_dim = cfg.value("embedding_dim", 0); | |
| 487 | ✗ | RECSTORE_LOG_EVERY_MS(INFO, 2000) | |
| 488 | ✗ | << "InitEmbeddingTable: table=" << request->table_name() | |
| 489 | ✗ | << ", num_embeddings=" << num_embeddings | |
| 490 | ✗ | << ", embedding_dim=" << embedding_dim; | |
| 491 | |||
| 492 | ✗ | bool init_success = cache_ps_->InitTable( | |
| 493 | request->table_name(), num_embeddings, embedding_dim); | ||
| 494 | ✗ | reply->set_success(init_success); | |
| 495 | ✗ | } else { | |
| 496 | ✗ | LOG(WARNING) << "InitEmbeddingTable called without config_payload"; | |
| 497 | ✗ | reply->set_success(false); | |
| 498 | } | ||
| 499 | ✗ | } catch (const std::exception& e) { | |
| 500 | ✗ | LOG(ERROR) << "InitEmbeddingTable error: " << e.what(); | |
| 501 | ✗ | reply->set_success(false); | |
| 502 | ✗ | } | |
| 503 | |||
| 504 | #ifdef ENABLE_PERF_REPORT | ||
| 505 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 506 | double start_us_for_key = | ||
| 507 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 508 | start_time.time_since_epoch()) | ||
| 509 | .count(); | ||
| 510 | auto duration = | ||
| 511 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 512 | end_time - start_time) | ||
| 513 | .count(); | ||
| 514 | std::string op_latency_key = | ||
| 515 | "InitEmbeddingTable|" + | ||
| 516 | std::to_string(static_cast<uint64_t>(start_us_for_key)); | ||
| 517 | report("op_latency", | ||
| 518 | op_latency_key.c_str(), | ||
| 519 | "recserver_us", | ||
| 520 | static_cast<double>(duration)); | ||
| 521 | #endif | ||
| 522 | |||
| 523 | ✗ | return Status::OK; | |
| 524 | } | ||
| 525 | |||
| 526 | private: | ||
| 527 | CachePS* cache_ps_; | ||
| 528 | std::atomic<uint64_t> total_get_requests_{0}; | ||
| 529 | std::atomic<uint64_t> total_put_requests_{0}; | ||
| 530 | std::atomic<uint64_t> total_get_keys_{0}; | ||
| 531 | std::atomic<uint64_t> total_put_keys_{0}; | ||
| 532 | std::atomic<uint64_t> total_get_bytes_{0}; | ||
| 533 | std::atomic<uint64_t> total_put_bytes_{0}; | ||
| 534 | std::chrono::steady_clock::time_point start_time_; | ||
| 535 | }; | ||
| 536 | |||
| 537 | namespace recstore { | ||
| 538 | class GRPCParameterServer : public BaseParameterServer { | ||
| 539 | public: | ||
| 540 | ✗ | GRPCParameterServer() = default; | |
| 541 | |||
| 542 | ✗ | void Run() { | |
| 543 | // Check whether multi-shard mode is configured | ||
| 544 | ✗ | int num_shards = 1; // default: single shard | |
| 545 | ✗ | if (config_["cache_ps"].contains("num_shards")) { | |
| 546 | ✗ | num_shards = config_["cache_ps"]["num_shards"]; | |
| 547 | } | ||
| 548 | const std::optional<int> local_shard_id = | ||
| 549 | ✗ | FLAGS_grpc_local_shard_id >= 0 | |
| 550 | ✗ | ? std::make_optional(FLAGS_grpc_local_shard_id) | |
| 551 | ✗ | : std::nullopt; | |
| 552 | |||
| 553 | ✗ | if (num_shards > 1) { | |
| 554 | // Multi-server startup | ||
| 555 | std::cout << "Starting distributed parameter server (gRPC), number " | ||
| 556 | ✗ | "of shards: " | |
| 557 | ✗ | << num_shards << std::endl; | |
| 558 | |||
| 559 | ✗ | if (!config_["cache_ps"].contains("servers")) { | |
| 560 | ✗ | LOG(FATAL) << "num_shards > 1 but cache_ps.servers is missing"; | |
| 561 | ✗ | return; | |
| 562 | } | ||
| 563 | |||
| 564 | ✗ | const auto& cache_ps_config = config_["cache_ps"]; | |
| 565 | ✗ | auto servers = SelectGRPCShardConfigs(cache_ps_config, local_shard_id); | |
| 566 | ✗ | const auto configured_servers = cache_ps_config["servers"]; | |
| 567 | ✗ | if (configured_servers.size() != num_shards) { | |
| 568 | ✗ | LOG(FATAL) << "servers count (" << configured_servers.size() | |
| 569 | ✗ | << ") does not match num_shards (" << num_shards << ")"; | |
| 570 | return; | ||
| 571 | } | ||
| 572 | ✗ | if (local_shard_id.has_value() && servers.empty()) { | |
| 573 | ✗ | LOG(FATAL) << "grpc_local_shard_id=" << *local_shard_id | |
| 574 | ✗ | << " is not present in cache_ps.servers"; | |
| 575 | return; | ||
| 576 | } | ||
| 577 | ✗ | if (!local_shard_id.has_value() && | |
| 578 | ✗ | servers.size() != configured_servers.size()) { | |
| 579 | ✗ | LOG(FATAL) << "Selected shard count (" << servers.size() | |
| 580 | ✗ | << ") does not match configured server count (" | |
| 581 | ✗ | << configured_servers.size() << ")"; | |
| 582 | return; | ||
| 583 | } | ||
| 584 | |||
| 585 | ✗ | std::vector<std::thread> server_threads; | |
| 586 | |||
| 587 | ✗ | for (auto& server_config : servers) { | |
| 588 | ✗ | server_threads.emplace_back([this, server_config]() { | |
| 589 | try { | ||
| 590 | ✗ | std::string host = server_config["host"]; | |
| 591 | ✗ | int port = server_config["port"]; | |
| 592 | ✗ | int shard = server_config["shard"]; | |
| 593 | |||
| 594 | ✗ | std::string server_address = host + ":" + std::to_string(port); | |
| 595 | |||
| 596 | ✗ | nlohmann::json shard_config = config_["cache_ps"]; | |
| 597 | ✗ | if (shard_config.contains("base_kv_config") && | |
| 598 | ✗ | shard_config["base_kv_config"].is_object()) { | |
| 599 | ✗ | auto& base_kv_config = shard_config["base_kv_config"]; | |
| 600 | ✗ | AppendShardSuffixIfPresent(base_kv_config, "path", shard); | |
| 601 | ✗ | AppendShardSuffixIfPresent(base_kv_config, "rocksdb_path", shard); | |
| 602 | ✗ | AppendShardSuffixToNestedFilePaths(base_kv_config, shard); | |
| 603 | ✗ | LOG(INFO) << "gRPC shard " << shard | |
| 604 | ✗ | << " using base_kv_config: " << base_kv_config.dump(); | |
| 605 | } | ||
| 606 | |||
| 607 | ✗ | auto cache_ps = std::make_unique<CachePS>(shard_config); | |
| 608 | ✗ | ParameterServiceImpl service(cache_ps.get()); | |
| 609 | |||
| 610 | ✗ | grpc::EnableDefaultHealthCheckService(true); | |
| 611 | ✗ | grpc::reflection::InitProtoReflectionServerBuilderPlugin(); | |
| 612 | ✗ | ServerBuilder builder; | |
| 613 | ✗ | builder.AddListeningPort( | |
| 614 | ✗ | server_address, grpc::InsecureServerCredentials()); | |
| 615 | ✗ | builder.RegisterService(&service); | |
| 616 | ✗ | builder.SetMaxReceiveMessageSize(-1); // Unlimited | |
| 617 | ✗ | builder.SetMaxSendMessageSize(-1); // Unlimited | |
| 618 | ✗ | std::unique_ptr<Server> server(builder.BuildAndStart()); | |
| 619 | |||
| 620 | ✗ | if (!server) { | |
| 621 | std::string err_msg = fmt::format( | ||
| 622 | "FATAL: Failed to start gRPC server shard {} " | ||
| 623 | "on {}. " | ||
| 624 | "Port might be in use or invalid " | ||
| 625 | "configuration. " | ||
| 626 | "Check if port {} is already occupied.", | ||
| 627 | shard, | ||
| 628 | server_address, | ||
| 629 | ✗ | port); | |
| 630 | ✗ | std::cerr << err_msg << std::endl; | |
| 631 | ✗ | LOG(FATAL) << err_msg; | |
| 632 | return; | ||
| 633 | ✗ | } | |
| 634 | ✗ | std::cout << "Server shard " << shard << " listening on " | |
| 635 | ✗ | << server_address << std::endl; | |
| 636 | ✗ | server->Wait(); | |
| 637 | ✗ | } catch (const std::exception& e) { | |
| 638 | std::cerr << "FATAL: Uncaught exception in shard thread: " | ||
| 639 | ✗ | << e.what() << std::endl; | |
| 640 | ✗ | LOG(FATAL) << "Uncaught exception in shard thread: " << e.what(); | |
| 641 | ✗ | } catch (...) { | |
| 642 | ✗ | std::cerr << "FATAL: Unknown exception in shard thread" | |
| 643 | ✗ | << std::endl; | |
| 644 | ✗ | LOG(FATAL) << "Unknown exception in shard thread"; | |
| 645 | ✗ | } | |
| 646 | }); | ||
| 647 | } | ||
| 648 | |||
| 649 | // Wait for all server threads | ||
| 650 | ✗ | for (auto& t : server_threads) { | |
| 651 | ✗ | t.join(); | |
| 652 | } | ||
| 653 | ✗ | } else { | |
| 654 | // Single-server startup | ||
| 655 | ✗ | std::cout << "Starting single parameter server" << std::endl; | |
| 656 | ✗ | std::string server_address("0.0.0.0:15000"); | |
| 657 | ✗ | auto cache_ps = std::make_unique<CachePS>(config_["cache_ps"]); | |
| 658 | ✗ | ParameterServiceImpl service(cache_ps.get()); | |
| 659 | |||
| 660 | ✗ | std::atomic<bool> metrics_running{true}; | |
| 661 | ✗ | std::thread metrics_thread([&service, &metrics_running]() { | |
| 662 | ✗ | while (metrics_running) { | |
| 663 | ✗ | std::this_thread::sleep_for(std::chrono::seconds(10)); | |
| 664 | ✗ | service.PrintMetrics(); | |
| 665 | ✗ | service.ResetMetrics(); | |
| 666 | } | ||
| 667 | ✗ | }); | |
| 668 | |||
| 669 | ✗ | grpc::EnableDefaultHealthCheckService(true); | |
| 670 | ✗ | grpc::reflection::InitProtoReflectionServerBuilderPlugin(); | |
| 671 | ✗ | ServerBuilder builder; | |
| 672 | ✗ | builder.AddListeningPort( | |
| 673 | ✗ | server_address, grpc::InsecureServerCredentials()); | |
| 674 | ✗ | builder.RegisterService(&service); | |
| 675 | ✗ | builder.SetMaxReceiveMessageSize(-1); // Unlimited | |
| 676 | ✗ | builder.SetMaxSendMessageSize(-1); // Unlimited | |
| 677 | ✗ | std::unique_ptr<Server> server(builder.BuildAndStart()); | |
| 678 | ✗ | std::cerr << "sever built succesfully" << std::endl; | |
| 679 | ✗ | if (!server) { | |
| 680 | std::string err_msg = fmt::format( | ||
| 681 | "FATAL: Failed to start gRPC server on {}. " | ||
| 682 | "Port might be in use or invalid configuration.", | ||
| 683 | ✗ | server_address); | |
| 684 | ✗ | std::cerr << err_msg << std::endl; | |
| 685 | ✗ | LOG(FATAL) << err_msg; | |
| 686 | metrics_running = false; | ||
| 687 | if (metrics_thread.joinable()) { | ||
| 688 | metrics_thread.join(); | ||
| 689 | } | ||
| 690 | return; | ||
| 691 | ✗ | } | |
| 692 | ✗ | std::cout << "Server listening on " << server_address << std::endl; | |
| 693 | ✗ | server->Wait(); | |
| 694 | |||
| 695 | ✗ | metrics_running = false; | |
| 696 | ✗ | if (metrics_thread.joinable()) { | |
| 697 | ✗ | metrics_thread.join(); | |
| 698 | } | ||
| 699 | ✗ | } | |
| 700 | } | ||
| 701 | }; | ||
| 702 | |||
| 703 | FACTORY_REGISTER(BaseParameterServer, GRPCParameterServer, GRPCParameterServer); | ||
| 704 | |||
| 705 | } // namespace recstore | ||
| 706 | |||
| 707 | #ifndef RECSTORE_NO_SERVER_MAIN | ||
| 708 | ✗ | int main(int argc, char** argv) { | |
| 709 | ✗ | base::Init(&argc, &argv); | |
| 710 | ✗ | xmh::Reporter::StartReportThread(2000); | |
| 711 | const std::string config_path = | ||
| 712 | ✗ | FLAGS_config_path.empty() | |
| 713 | ✗ | ? base::ResolveRecStoreConfigPath().string() | |
| 714 | ✗ | : FLAGS_config_path; | |
| 715 | ✗ | std::ifstream config_file(config_path); | |
| 716 | ✗ | if (!config_file.is_open()) { | |
| 717 | ✗ | throw std::runtime_error("Cannot open config file: " + config_path); | |
| 718 | } | ||
| 719 | ✗ | nlohmann::json ex; | |
| 720 | ✗ | config_file >> ex; | |
| 721 | ✗ | recstore::GRPCParameterServer ps; | |
| 722 | ✗ | std::cout << "Parameter server config: " << ex.dump(2) << std::endl; | |
| 723 | ✗ | ps.Init(ex); | |
| 724 | ✗ | ps.Run(); | |
| 725 | ✗ | return 0; | |
| 726 | ✗ | } | |
| 727 | #endif | ||
| 728 |