GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 0.3% 2 / 0 / 691
Functions: 2.4% 1 / 0 / 42
Branches: 0.2% 2 / 0 / 892

ps/rdma/rdma_ps_client_adapter.cc
Line Branch Exec Source
#include "ps/rdma/rdma_ps_client_adapter.h"

#include <algorithm>
#include <atomic>
#include <cerrno>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <folly/portability/GFlags.h>
#include <folly/init/Init.h>

#include "framework/common/ps_client_config_adapter.h"
#include "ps/base/config.h"
#include "base/hash.h"
#include "ps/rdma/rdma_common.h"
#include "ps/rdma/rc_options.h"
24
// gflags defined in another translation unit. EnsureClientInitialized
// overwrites all five from the adapter's JSON config before any RPC is issued.
DECLARE_int32(global_id);
DECLARE_int32(num_server_processes);
DECLARE_int32(num_client_processes);
DECLARE_int32(value_size);
DECLARE_int32(max_kv_num_per_request);
// Declared and defined in this file. EnsureClientInitialized overrides it from
// the RECSTORE_RDMA_TRANSPORT_MODE environment variable when that is set.
DECLARE_string(rdma_transport_mode);
DEFINE_string(rdma_transport_mode, "rc_write", "RDMA transport mode: rc_write");
// When true, GetPrefetchResultFlat clears its output vector instead of copying
// the response payload into it (benchmarking only).
DEFINE_bool(rdma_adapter_skip_prefetch_result_copy,
            false,
            "Benchmark-only option to skip copying RDMA prefetch results into "
            "the GetPrefetchResultFlat output vector");
36
37 namespace recstore {
38
39 namespace {
40 bool AdapterProfileEnabled() {
41 const char* value = std::getenv("RECSTORE_RDMA_ADAPTER_PROFILE");
42 return value != nullptr && std::string(value) != "0";
43 }
44
45 void SetIntFlagFromEnv(const char* env_name, int32_t* flag_value) {
46 const char* value = std::getenv(env_name);
47 if (value == nullptr || *value == '\0') {
48 return;
49 }
50 char* end = nullptr;
51 const long parsed = std::strtol(value, &end, 10);
52 if (end != value && *end == '\0') {
53 *flag_value = static_cast<int32_t>(parsed);
54 }
55 }
56
57 void ApplyRdmaFlagsFromEnv() {
58 if (const char* value = std::getenv("RECSTORE_RDMA_RC_NAMESPACE")) {
59 FLAGS_rdma_rc_namespace = value;
60 }
61 if (const char* value = std::getenv("RECSTORE_RDMA_CONTROL_PLANE_HOST")) {
62 FLAGS_rdma_control_plane_host = value;
63 }
64 SetIntFlagFromEnv(
65 "RECSTORE_RDMA_CONTROL_PLANE_PORT", &FLAGS_rdma_control_plane_port);
66 SetIntFlagFromEnv(
67 "RECSTORE_RDMA_WAIT_TIMEOUT_MS", &FLAGS_rdma_wait_timeout_ms);
68 SetIntFlagFromEnv("RECSTORE_RDMA_RC_QPS_PER_CLIENT_PER_SHARD",
69 &FLAGS_rdma_rc_qps_per_client_per_shard);
70 SetIntFlagFromEnv(
71 "RECSTORE_RDMA_RC_SLOTS_PER_QP", &FLAGS_rdma_rc_slots_per_qp);
72 SetIntFlagFromEnv("RECSTORE_RDMA_RC_SERVER_COROUTINES_PER_THREAD",
73 &FLAGS_rdma_rc_server_coroutines_per_thread);
74 SetIntFlagFromEnv(
75 "RECSTORE_RDMA_RC_SERVER_GET_WORKERS", &FLAGS_rdma_rc_server_get_workers);
76 }
77
78 std::int64_t NsSince(std::chrono::steady_clock::time_point start,
79 std::chrono::steady_clock::time_point end) {
80 return std::chrono::duration_cast<std::chrono::nanoseconds>(end - start)
81 .count();
82 }
83
84 int ValueSizeHintFromBaseKvConfig(const json& base_kv_config,
85 int fallback_value_size) {
86 if (!base_kv_config.is_object()) {
87 return fallback_value_size;
88 }
89 if (!base_kv_config.contains("value") ||
90 !base_kv_config["value"].is_object()) {
91 return fallback_value_size;
92 }
93 return base_kv_config["value"].value(
94 "default_value_size_hint", fallback_value_size);
95 }
96
97 std::vector<std::string> ReadProcessArgv() {
98 std::ifstream cmdline("/proc/self/cmdline", std::ios::binary);
99 std::vector<std::string> argv;
100 if (!cmdline.is_open()) {
101 return argv;
102 }
103
104 std::string current;
105 char ch = '\0';
106 while (cmdline.get(ch)) {
107 if (ch == '\0') {
108 if (!current.empty()) {
109 argv.push_back(current);
110 current.clear();
111 }
112 continue;
113 }
114 current.push_back(ch);
115 }
116 if (!current.empty()) {
117 argv.push_back(current);
118 }
119 return argv;
120 }
121 } // namespace
122
123 int RDMAPSClientAdapter::PartitionKey(uint64_t key) const {
124 CHECK_GT(num_shards_, 0);
125 if (hash_method_ == "city_hash") {
126 return static_cast<int>(GetHash(key) % static_cast<uint64_t>(num_shards_));
127 }
128 if (hash_method_ == "simple_mod") {
129 return static_cast<int>(key % static_cast<uint64_t>(num_shards_));
130 }
131 throw std::runtime_error("unsupported shard hash method: " + hash_method_);
132 }
133
134 std::vector<RDMAPSClientAdapter::ShardChunk>
135 RDMAPSClientAdapter::BuildChunks(base::ConstArray<uint64_t> keys) const {
136 std::vector<std::vector<uint64_t>> shard_keys(num_shards_);
137 std::vector<std::vector<std::size_t>> shard_positions(num_shards_);
138
139 for (std::size_t i = 0; i < keys.Size(); ++i) {
140 const int shard = PartitionKey(keys[i]);
141 shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]);
142 shard_positions[static_cast<std::size_t>(shard)].push_back(i);
143 }
144
145 std::vector<ShardChunk> chunks;
146 const std::size_t max_keys_per_rpc = MaxGetKeysPerRpc();
147 for (int shard = 0; shard < num_shards_; ++shard) {
148 const int client_index = shard_to_client_index_.at(shard);
149 for (std::size_t offset = 0;
150 offset < shard_keys[static_cast<std::size_t>(shard)].size();
151 offset += max_keys_per_rpc) {
152 const std::size_t end =
153 std::min(offset + max_keys_per_rpc,
154 shard_keys[static_cast<std::size_t>(shard)].size());
155 ShardChunk chunk;
156 chunk.shard_id = shard;
157 chunk.client_index = client_index;
158 chunk.keys.assign(
159 shard_keys[static_cast<std::size_t>(shard)].begin() + offset,
160 shard_keys[static_cast<std::size_t>(shard)].begin() + end);
161 chunk.positions.assign(
162 shard_positions[static_cast<std::size_t>(shard)].begin() + offset,
163 shard_positions[static_cast<std::size_t>(shard)].begin() + end);
164 chunks.push_back(std::move(chunk));
165 }
166 }
167 return chunks;
168 }
169
170 bool RDMAPSClientAdapter::FinalizeBatchIfNeeded(BatchRequest* batch) {
171 if (batch == nullptr) {
172 return false;
173 }
174 if (batch->assembled) {
175 return batch->status_code ==
176 static_cast<std::int32_t>(petps::RpcStatus::kOk);
177 }
178
179 batch->status_code = static_cast<std::int32_t>(petps::RpcStatus::kOk);
180 for (const auto& pending : batch->shard_rpcs) {
181 const auto* status_word = petps::FixedSlotStatusWord(
182 pending.recv_buffer, pending.key_count, FLAGS_value_size);
183 if (*status_word != static_cast<std::int32_t>(petps::RpcStatus::kOk)) {
184 batch->status_code = *status_word;
185 break;
186 }
187 }
188
189 const int embedding_dim = FLAGS_value_size / sizeof(float);
190 if (batch->status_code == static_cast<std::int32_t>(petps::RpcStatus::kOk)) {
191 for (const auto& pending : batch->shard_rpcs) {
192 const float* shard_values =
193 static_cast<const float*>(pending.recv_buffer);
194 for (std::size_t i = 0; i < pending.original_positions.size(); ++i) {
195 std::memcpy(
196 batch->user_buffer + pending.original_positions[i] * embedding_dim,
197 shard_values + i * embedding_dim,
198 FLAGS_value_size);
199 }
200 }
201 }
202
203 auto* batch_status_word = reinterpret_cast<std::int32_t*>(
204 reinterpret_cast<char*>(batch->user_buffer) +
205 batch->total_key_count * static_cast<std::size_t>(FLAGS_value_size));
206 *batch_status_word = batch->status_code;
207 batch->assembled = true;
208 return batch->status_code == static_cast<std::int32_t>(petps::RpcStatus::kOk);
209 }
210
211 void RDMAPSClientAdapter::WaitShardRpcsCooperatively(
212 const std::vector<PendingShardRpc>& shard_rpcs) {
213 std::vector<bool> finished(shard_rpcs.size(), false);
214 std::size_t remaining = shard_rpcs.size();
215 while (remaining > 0) {
216 bool made_progress = false;
217 for (std::size_t i = 0; i < shard_rpcs.size(); ++i) {
218 if (finished[i]) {
219 continue;
220 }
221 const auto& pending = shard_rpcs[i];
222 auto& client =
223 shard_clients_[static_cast<std::size_t>(pending.client_index)];
224 if (!client->QueryRPCFinished(pending.rpc_id)) {
225 continue;
226 }
227 client->WaitRPCFinish(pending.rpc_id);
228 finished[i] = true;
229 made_progress = true;
230 --remaining;
231 }
232 if (!made_progress) {
233 std::this_thread::yield();
234 }
235 }
236 }
237
238 void InitializeRdmaProcessRuntime() {
239 static std::once_flag init_once;
240 std::call_once(init_once, []() {
241 // Python entrypoints pass application CLI flags that are not gflags.
242 // Passing them to folly::init makes gflags abort before the RDMA client can
243 // start.
244 std::vector<std::string> argv_strings = {"recstore_rdma_client"};
245 std::vector<char*> argv_storage;
246 argv_storage.reserve(argv_strings.size() + 1);
247 for (auto& arg : argv_strings) {
248 argv_storage.push_back(arg.data());
249 }
250 argv_storage.push_back(nullptr);
251
252 int argc = static_cast<int>(argv_strings.size());
253 char** argv = argv_storage.data();
254 folly::init(&argc, &argv);
255 ApplyRdmaFlagsFromEnv();
256 });
257 }
258
259 4 RDMAPSClientAdapter::RDMAPSClientAdapter(json config)
260
2/4
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
4 : BasePSClient(config), config_(std::move(config)) {}
261
262 void RDMAPSClientAdapter::EnsureClientInitialized() {
263 std::lock_guard<std::mutex> guard(init_mu_);
264 if (initialized_) {
265 return;
266 }
267
268 const json cache_ps_cfg =
269 config_.contains("cache_ps") ? config_["cache_ps"] : json::object();
270 const json client_cfg =
271 config_.contains("client") ? config_["client"] : json::object();
272 const json dist_cfg = ResolveFrameworkDistributedClientConfig(config_);
273 const int logical_client_id =
274 config_.value("rdma_logical_client_id", FLAGS_rdma_rc_client_id_base);
275
276 num_shards_ = dist_cfg.value("num_shards", 1);
277 hash_method_ = dist_cfg.value("hash_method", "city_hash");
278 if (FLAGS_global_id < num_shards_) {
279 FLAGS_num_server_processes = num_shards_;
280 FLAGS_num_client_processes = 1;
281 FLAGS_global_id = num_shards_;
282 } else if (FLAGS_num_server_processes != num_shards_) {
283 throw std::runtime_error(
284 "RDMA num_server_processes must match distributed_client.num_shards");
285 }
286 FLAGS_value_size =
287 cache_ps_cfg.contains("base_kv_config")
288 ? ValueSizeHintFromBaseKvConfig(
289 cache_ps_cfg["base_kv_config"], FLAGS_value_size)
290 : FLAGS_value_size;
291 FLAGS_max_kv_num_per_request =
292 dist_cfg.value("max_keys_per_request", FLAGS_max_kv_num_per_request);
293 if (const char* mode = std::getenv("RECSTORE_RDMA_TRANSPORT_MODE")) {
294 FLAGS_rdma_transport_mode = mode;
295 }
296
297 shard_clients_.clear();
298 shard_to_client_index_.clear();
299 client_ = nullptr;
300
301 if (num_shards_ <= 1) {
302 shard_clients_.push_back(std::make_unique<petps::PetPSClient>(
303 client_cfg.value("host", std::string("127.0.0.1")),
304 client_cfg.value("port", 25000),
305 client_cfg.value("shard", 0),
306 logical_client_id));
307 client_ = shard_clients_.front().get();
308 shard_to_client_index_[0] = 0;
309 } else {
310 const auto servers_it = dist_cfg.find("servers");
311 if (servers_it == dist_cfg.end() || !servers_it->is_array() ||
312 servers_it->empty()) {
313 throw std::runtime_error(
314 "RDMA distributed_client.servers must be provided for multi-shard "
315 "configuration");
316 }
317
318 CHECK_EQ(static_cast<int>(servers_it->size()), num_shards_)
319 << "RDMA distributed_client.servers size must equal num_shards";
320 for (const auto& server : *servers_it) {
321 const int shard = server.value("shard", -1);
322 if (shard < 0) {
323 throw std::runtime_error(
324 "RDMA distributed_client.servers[].shard must be explicit");
325 }
326 shard_clients_.push_back(std::make_unique<petps::PetPSClient>(
327 server.value("host", std::string("127.0.0.1")),
328 server.value("port", 25000),
329 shard,
330 logical_client_id));
331 shard_to_client_index_[shard] =
332 static_cast<int>(shard_clients_.size() - 1);
333 }
334 }
335
336 initialized_ = true;
337 }
338
339 void RDMAPSClientAdapter::EnsureThreadInitialized() {
340 EnsureClientInitialized();
341 const std::thread::id tid = std::this_thread::get_id();
342 std::lock_guard<std::mutex> guard(thread_init_mu_);
343 if (initialized_threads_.find(tid) != initialized_threads_.end()) {
344 return;
345 }
346
347 if (num_shards_ <= 1) {
348 if (client_ != nullptr) {
349 client_->InitThread();
350 }
351 } else {
352 for (auto& shard_client : shard_clients_) {
353 shard_client->InitThread();
354 }
355 }
356
357 initialized_threads_.insert(tid);
358 }
359
360 void RDMAPSClientAdapter::EnsureTableReady(const std::string& table_name,
361 int64_t embedding_dim) {
362 std::lock_guard<std::mutex> guard(state_mu_);
363 const auto it = tables_.find(table_name);
364 if (it == tables_.end()) {
365 throw std::runtime_error("RDMA table is not initialized: " + table_name);
366 }
367 if (static_cast<int64_t>(it->second.config.embedding_dim) != embedding_dim) {
368 throw std::runtime_error(
369 "RDMA embedding dimension mismatch for table " + table_name);
370 }
371 }
372
373 int64_t RDMAPSClientAdapter::DefaultEmbeddingDimOrThrow() const {
374 if (tables_.empty()) {
375 throw std::runtime_error(
376 "RDMA table metadata is empty; call InitEmbeddingTable first");
377 }
378 return static_cast<int64_t>(tables_.begin()->second.config.embedding_dim);
379 }
380
381 std::size_t RDMAPSClientAdapter::MaxGetKeysPerRpc() const {
382 const std::size_t response_limited = petps::GetKeysPerRpcByResponseBudget(
383 static_cast<std::size_t>(FLAGS_value_size),
384 static_cast<std::size_t>(FLAGS_rdma_rc_mtu_bytes),
385 static_cast<std::size_t>(FLAGS_rdma_rc_target_response_mtu));
386 const std::size_t request_limited =
387 petps::PutPayloadBudget(
388 static_cast<std::size_t>(FLAGS_rdma_rc_request_slot_bytes)) /
389 sizeof(std::uint64_t);
390 std::size_t limit = static_cast<std::size_t>(FLAGS_max_kv_num_per_request);
391 if (response_limited > 0) {
392 limit = std::min(limit, response_limited);
393 }
394 if (request_limited > 0) {
395 limit = std::min(limit, request_limited);
396 }
397 return std::max<std::size_t>(limit, 1);
398 }
399
400 std::size_t RDMAPSClientAdapter::MaxInFlightGetRpcs() const {
401 const std::size_t qps = static_cast<std::size_t>(
402 std::max(FLAGS_rdma_rc_qps_per_client_per_shard, 1));
403 const std::size_t slots =
404 static_cast<std::size_t>(std::max(FLAGS_rdma_rc_slots_per_qp, 1));
405 return std::max<std::size_t>(qps * slots, 1);
406 }
407
408 RDMAPSClientAdapter::PrefetchState
409 RDMAPSClientAdapter::GetPrefetchState(uint64_t prefetch_id) {
410 std::lock_guard<std::mutex> guard(state_mu_);
411 const auto it = prefetches_.find(prefetch_id);
412 if (it == prefetches_.end()) {
413 throw std::runtime_error(
414 "Unknown RDMA prefetch id: " + std::to_string(prefetch_id));
415 }
416 return it->second;
417 }
418
419 void RDMAPSClientAdapter::MarkPrefetchConsumed(uint64_t prefetch_id) {
420 std::lock_guard<std::mutex> guard(state_mu_);
421 const auto it = prefetches_.find(prefetch_id);
422 if (it == prefetches_.end()) {
423 return;
424 }
425 free_prefetch_buffer_ids_.push_back(it->second.buffer_id);
426 prefetches_.erase(it);
427 }
428
429 bool RDMAPSClientAdapter::QueryRPCFinished(int rpc_id) {
430 if (rpc_id >= 0 && num_shards_ <= 1) {
431 return client_ != nullptr ? client_->QueryRPCFinished(rpc_id) : true;
432 }
433
434 std::lock_guard<std::mutex> guard(batches_mu_);
435 auto it = batches_.find(rpc_id);
436 CHECK(it != batches_.end());
437
438 for (const auto& pending : it->second.shard_rpcs) {
439 if (!shard_clients_[static_cast<std::size_t>(pending.client_index)]
440 ->QueryRPCFinished(pending.rpc_id)) {
441 return false;
442 }
443 }
444
445 return FinalizeBatchIfNeeded(&it->second);
446 }
447
448 void RDMAPSClientAdapter::WaitRPCFinish(int rpc_id) {
449 if (rpc_id >= 0 && num_shards_ <= 1) {
450 if (client_ != nullptr) {
451 client_->WaitRPCFinish(rpc_id);
452 }
453 return;
454 }
455
456 std::vector<PendingShardRpc> shard_rpcs;
457 {
458 std::lock_guard<std::mutex> guard(batches_mu_);
459 auto it = batches_.find(rpc_id);
460 CHECK(it != batches_.end());
461 if (it->second.assembled) {
462 return;
463 }
464 shard_rpcs = it->second.shard_rpcs;
465 }
466
467 WaitShardRpcsCooperatively(shard_rpcs);
468
469 {
470 std::lock_guard<std::mutex> guard(batches_mu_);
471 auto it = batches_.find(rpc_id);
472 CHECK(it != batches_.end());
473 FinalizeBatchIfNeeded(&it->second);
474 }
475 }
476
477 void RDMAPSClientAdapter::RevokeRPCResource(int rpc_id) {
478 if (rpc_id >= 0 && num_shards_ <= 1) {
479 if (client_ != nullptr) {
480 client_->RevokeRPCResource(rpc_id);
481 }
482 return;
483 }
484
485 std::lock_guard<std::mutex> guard(batches_mu_);
486 auto it = batches_.find(rpc_id);
487 CHECK(it != batches_.end());
488
489 for (const auto& pending : it->second.shard_rpcs) {
490 shard_clients_[static_cast<std::size_t>(pending.client_index)]
491 ->RevokeRPCResource(pending.rpc_id);
492 }
493
494 batches_.erase(it);
495 }
496
497 float* RDMAPSClientAdapter::AcquirePrefetchBuffer(std::size_t bytes,
498 std::size_t* buffer_id) {
499 if (buffer_id == nullptr) {
500 throw std::invalid_argument("RDMA prefetch buffer_id must not be null");
501 }
502 std::lock_guard<std::mutex> guard(state_mu_);
503 while (!free_prefetch_buffer_ids_.empty()) {
504 const std::size_t id = free_prefetch_buffer_ids_.back();
505 free_prefetch_buffer_ids_.pop_back();
506 if (id < prefetch_buffer_capacities_.size() &&
507 prefetch_buffer_capacities_[id] >= bytes) {
508 *buffer_id = id;
509 return reinterpret_cast<float*>(prefetch_buffers_[id].get());
510 }
511 }
512
513 const std::size_t id = prefetch_buffers_.size();
514 prefetch_buffers_.push_back(std::make_unique<char[]>(bytes));
515 prefetch_buffer_capacities_.push_back(bytes);
516 *buffer_id = id;
517 return reinterpret_cast<float*>(prefetch_buffers_.back().get());
518 }
519
520 void RDMAPSClientAdapter::ReleasePrefetchBuffer(std::size_t buffer_id) {
521 std::lock_guard<std::mutex> guard(state_mu_);
522 if (buffer_id < prefetch_buffers_.size()) {
523 free_prefetch_buffer_ids_.push_back(buffer_id);
524 }
525 }
526
527 const float* RDMAPSClientAdapter::BorrowPrefetchResult(
528 const PrefetchState& state,
529 std::int32_t* status_code,
530 std::size_t* response_bytes) {
531 if (!state.borrowed_response || client_ == nullptr) {
532 return nullptr;
533 }
534 auto* pet_client = dynamic_cast<petps::PetPSClient*>(client_);
535 if (pet_client == nullptr) {
536 return nullptr;
537 }
538 std::size_t key_count = 0;
539 const float* payload = pet_client->BorrowGetResultPayload(
540 state.rpc_id, &key_count, response_bytes, status_code);
541 if (payload == nullptr ||
542 key_count != static_cast<std::size_t>(state.key_count)) {
543 return nullptr;
544 }
545 return payload;
546 }
547
// Submits a GET for `keys` whose results land in `values`. Returns an id that
// QueryRPCFinished / WaitRPCFinish / RevokeRPCResource accept.
//
// `values` needs keys.Size() * FLAGS_value_size bytes plus a trailing int32
// status word. That word is set to kPending here and finalized later.
// For an empty `keys`, only values[0..4) is written (as kOk) and 0 is
// returned. Returns -1 if the single-shard client is missing.
//
// Paths:
//  * Single shard, keys fit in one RPC: forwarded to client_ unchanged. The
//    client's own id is returned, presumably non-negative (QueryRPCFinished
//    relies on that; confirm).
//  * Single shard, too many keys: contiguous slices of MaxGetKeysPerRpc().
//  * Multiple shards: per-shard chunks from BuildChunks.
// In both split paths, sub-RPCs are issued in windows of at most
// MaxInFlightGetRpcs(). Each window is waited on and revoked before more
// are issued, and its descriptors are appended to a BatchRequest. That batch
// is registered under a fresh negative id. FinalizeBatchIfNeeded later
// gathers the rows into `values`. The split paths therefore always block
// until every sub-RPC has finished, whatever `isAsync` says.
//
// NOTE(review): sub-RPCs are revoked inside drain_and_release_window, yet
// FinalizeBatchIfNeeded reads their recv_buffer afterwards. That is only safe
// if RevokeRPCResource leaves the receive buffer contents intact. Confirm
// against PetPSClient.
int RDMAPSClientAdapter::SubmitGetParameter(
    base::ConstArray<uint64_t> keys,
    float* values,
    bool isAsync,
    int async_req_id) {
  EnsureThreadInitialized();
  if (keys.Size() == 0) {
    // Nothing to fetch: the status word sits at offset 0 for zero rows.
    auto* status =
        reinterpret_cast<std::int32_t*>(reinterpret_cast<char*>(values));
    *status = static_cast<std::int32_t>(petps::RpcStatus::kOk);
    return 0;
  }

  BatchRequest batch;
  batch.user_buffer = values;
  batch.total_key_count = keys.Size();
  auto* batch_status_word =
      petps::FixedSlotStatusWord(values, keys.Size(), FLAGS_value_size);
  *batch_status_word = static_cast<std::int32_t>(petps::RpcStatus::kPending);

  if (num_shards_ <= 1) {
    if (client_ == nullptr) {
      return -1;
    }
    const std::size_t max_keys_per_rpc = MaxGetKeysPerRpc();
    const std::size_t max_in_flight = MaxInFlightGetRpcs();
    const std::size_t total_keys = keys.Size();
    if (total_keys <= max_keys_per_rpc) {
      // Fast path: one RPC writes straight into `values`; no batch is created.
      return client_->GetParameter(keys, values, isAsync, async_req_id);
    }
    std::vector<PendingShardRpc> window;
    window.reserve(max_in_flight);
    auto drain_and_release_window = [this, &window, &batch]() {
      // Large model batches can split into more GET RPCs than the RC slot pool.
      // Keep submission bounded by waiting and freeing each window before
      // acquiring more slots.
      for (const auto& pending : window) {
        client_->WaitRPCFinish(pending.rpc_id);
      }
      for (const auto& pending : window) {
        batch.shard_rpcs.push_back(pending);
        client_->RevokeRPCResource(pending.rpc_id);
      }
      window.clear();
    };
    for (std::size_t offset = 0; offset < total_keys;
         offset += max_keys_per_rpc) {
      const std::size_t end = std::min(offset + max_keys_per_rpc, total_keys);
      std::vector<uint64_t> key_slice;
      key_slice.reserve(end - offset);
      std::vector<std::size_t> positions;
      positions.reserve(end - offset);
      for (std::size_t i = offset; i < end; ++i) {
        key_slice.push_back(keys[i]);
        positions.push_back(i);
      }
      // Room for the slice's rows plus its trailing int32 status word.
      void* recv = client_->GetReceiveBuffer(
          key_slice.size() * static_cast<std::size_t>(FLAGS_value_size) +
          sizeof(std::int32_t));
      const int rpc_id = client_->GetParameter(
          base::ConstArray<uint64_t>(key_slice),
          static_cast<float*>(recv),
          isAsync,
          async_req_id);
      // Shard id and client index are both 0 in single-shard mode.
      window.push_back(PendingShardRpc{
          0,
          0,
          rpc_id,
          std::move(positions),
          recv,
          key_slice.size(),
      });
      if (window.size() >= max_in_flight) {
        drain_and_release_window();
      }
    }
    if (!window.empty()) {
      drain_and_release_window();
    }
  } else {
    const std::size_t max_in_flight = MaxInFlightGetRpcs();
    std::vector<PendingShardRpc> window;
    window.reserve(max_in_flight);
    auto drain_and_release_window = [this, &window, &batch]() {
      // Same bounded-window policy as above, but the window spans several
      // shard clients, so completions are polled across all of them.
      WaitShardRpcsCooperatively(window);
      for (const auto& pending : window) {
        batch.shard_rpcs.push_back(pending);
        shard_clients_[static_cast<std::size_t>(pending.client_index)]
            ->RevokeRPCResource(pending.rpc_id);
      }
      window.clear();
    };
    for (const auto& chunk : BuildChunks(keys)) {
      BaseParameterClient* client = shard_clients_[chunk.client_index].get();
      void* recv = client->GetReceiveBuffer(
          chunk.keys.size() * static_cast<std::size_t>(FLAGS_value_size) +
          sizeof(std::int32_t));
      const int rpc_id = client->GetParameter(
          base::ConstArray<uint64_t>(chunk.keys),
          static_cast<float*>(recv),
          isAsync,
          async_req_id);
      window.push_back(PendingShardRpc{
          chunk.shard_id,
          chunk.client_index,
          rpc_id,
          chunk.positions,
          recv,
          chunk.keys.size(),
      });
      if (window.size() >= max_in_flight) {
        drain_and_release_window();
      }
    }
    if (!window.empty()) {
      drain_and_release_window();
    }
  }

  // Batch ids are taken from a decrementing counter and must stay negative
  // so they never collide with direct client rpc ids (checked below).
  int batch_id = 0;
  {
    std::lock_guard<std::mutex> guard(batches_mu_);
    batch_id = batch_rpc_id_acc_--;
    if (batch_id >= 0) {
      throw std::runtime_error("rdma batch rpc id exhausted negative range");
    }
    batches_[batch_id] = std::move(batch);
  }
  if (!isAsync) {
    // Synchronous callers get an assembled batch (rows + status) on return.
    WaitRPCFinish(batch_id);
  }
  return batch_id;
}
681
682 int RDMAPSClientAdapter::GetParameter(const base::ConstArray<uint64_t>& keys,
683 float* values) {
684 EnsureThreadInitialized();
685 if (keys.Size() == 0) {
686 return 0;
687 }
688
689 const std::size_t response_bytes =
690 petps::FixedSlotResponseBytes(keys.Size(), FLAGS_value_size);
691 float* recv = nullptr;
692 if (num_shards_ > 1) {
693 if (shard_clients_.empty()) {
694 return -1;
695 }
696 recv = static_cast<float*>(
697 shard_clients_.front()->GetReceiveBuffer(response_bytes));
698 } else {
699 recv = static_cast<float*>(client_->GetReceiveBuffer(response_bytes));
700 }
701
702 const int rpc_id = SubmitGetParameter(keys, recv, false, 0);
703 WaitRPCFinish(rpc_id);
704 const auto* status_word =
705 petps::FixedSlotStatusWord(recv, keys.Size(), FLAGS_value_size);
706 if (*status_word != static_cast<std::int32_t>(petps::RpcStatus::kOk)) {
707 RevokeRPCResource(rpc_id);
708 return -1;
709 }
710
711 std::memcpy(
712 values, recv, keys.Size() * static_cast<std::size_t>(FLAGS_value_size));
713 RevokeRPCResource(rpc_id);
714 return 0;
715 }
716
717 int RDMAPSClientAdapter::PutParameter(
718 const base::ConstArray<uint64_t>& keys,
719 const std::vector<std::vector<float>>& values) {
720 EnsureThreadInitialized();
721 if (num_shards_ <= 1) {
722 if (client_ == nullptr) {
723 return -1;
724 }
725 return client_->PutParameter(keys.ToVector(), values);
726 }
727 if (keys.Size() != values.size()) {
728 return -1;
729 }
730 if (keys.Size() == 0) {
731 return 0;
732 }
733
734 std::vector<std::vector<uint64_t>> shard_keys(num_shards_);
735 std::vector<std::vector<std::vector<float>>> shard_values(num_shards_);
736
737 for (std::size_t i = 0; i < keys.Size(); ++i) {
738 const int shard = PartitionKey(keys[i]);
739 shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]);
740 shard_values[static_cast<std::size_t>(shard)].push_back(values[i]);
741 }
742
743 for (int shard = 0; shard < num_shards_; ++shard) {
744 const int client_index = shard_to_client_index_.at(shard);
745 for (std::size_t offset = 0;
746 offset < shard_keys[static_cast<std::size_t>(shard)].size();
747 offset += static_cast<std::size_t>(FLAGS_max_kv_num_per_request)) {
748 const std::size_t end = std::min(
749 offset + static_cast<std::size_t>(FLAGS_max_kv_num_per_request),
750 shard_keys[static_cast<std::size_t>(shard)].size());
751 std::vector<uint64_t> key_slice(
752 shard_keys[static_cast<std::size_t>(shard)].begin() + offset,
753 shard_keys[static_cast<std::size_t>(shard)].begin() + end);
754 std::vector<std::vector<float>> value_slice(
755 shard_values[static_cast<std::size_t>(shard)].begin() + offset,
756 shard_values[static_cast<std::size_t>(shard)].begin() + end);
757 int rc =
758 shard_clients_[static_cast<std::size_t>(client_index)]->PutParameter(
759 key_slice, value_slice);
760 if (rc != 0) {
761 return rc;
762 }
763 }
764 }
765 return 0;
766 }
767
768 int RDMAPSClientAdapter::UpdateParameter(
769 const std::string& table_name,
770 const base::ConstArray<uint64_t>& keys,
771 const std::vector<std::vector<float>>* grads) {
772 if (grads == nullptr) {
773 return -1;
774 }
775 if (grads->empty()) {
776 return 0;
777 }
778 EnsureThreadInitialized();
779 if (num_shards_ <= 1) {
780 if (client_ == nullptr) {
781 return -1;
782 }
783 return client_->UpdateParameter(table_name, keys, grads);
784 }
785 if (keys.Size() != grads->size()) {
786 return -1;
787 }
788 if (keys.Size() == 0) {
789 return 0;
790 }
791
792 std::vector<std::vector<uint64_t>> shard_keys(num_shards_);
793 std::vector<std::vector<std::vector<float>>> shard_grads(num_shards_);
794
795 for (std::size_t i = 0; i < keys.Size(); ++i) {
796 const int shard = PartitionKey(keys[i]);
797 shard_keys[static_cast<std::size_t>(shard)].push_back(keys[i]);
798 shard_grads[static_cast<std::size_t>(shard)].push_back((*grads)[i]);
799 }
800
801 for (int shard = 0; shard < num_shards_; ++shard) {
802 if (shard_keys[static_cast<std::size_t>(shard)].empty()) {
803 continue;
804 }
805 const int client_index = shard_to_client_index_.at(shard);
806 const int rc =
807 shard_clients_[static_cast<std::size_t>(client_index)]->UpdateParameter(
808 table_name,
809 base::ConstArray<uint64_t>(
810 shard_keys[static_cast<std::size_t>(shard)]),
811 &shard_grads[static_cast<std::size_t>(shard)]);
812 if (rc != 0) {
813 return rc;
814 }
815 }
816 return 0;
817 }
818
819 int RDMAPSClientAdapter::UpdateParameterFlat(
820 const std::string& table_name,
821 const base::ConstArray<uint64_t>& keys,
822 const float* grads,
823 int64_t num_rows,
824 int64_t embedding_dim) {
825 EnsureTableReady(table_name, embedding_dim);
826 if (num_rows == 0) {
827 return 0;
828 }
829 if (grads == nullptr) {
830 return -1;
831 }
832 if (keys.Size() != static_cast<std::size_t>(num_rows)) {
833 return -1;
834 }
835 std::vector<std::vector<float>> updated;
836 updated.reserve(static_cast<std::size_t>(num_rows));
837 for (int64_t row = 0; row < num_rows; ++row) {
838 std::vector<float> values(static_cast<std::size_t>(embedding_dim), 0.0f);
839 for (int64_t col = 0; col < embedding_dim; ++col) {
840 const std::size_t idx =
841 static_cast<std::size_t>(row * embedding_dim + col);
842 values[static_cast<std::size_t>(col)] = grads[idx];
843 }
844 updated.push_back(std::move(values));
845 }
846
847 return UpdateParameter(table_name, keys, &updated);
848 }
849
850 int RDMAPSClientAdapter::InitEmbeddingTable(
851 const std::string& table_name, const EmbeddingTableConfig& config) {
852 EnsureThreadInitialized();
853 if (num_shards_ <= 1) {
854 if (client_ == nullptr) {
855 return -1;
856 }
857 const int init_rc = client_->InitEmbeddingTable(
858 table_name, config.num_embeddings, config.embedding_dim);
859 if (init_rc != 0) {
860 return init_rc;
861 }
862 } else {
863 for (auto& shard_client : shard_clients_) {
864 const int rc = shard_client->InitEmbeddingTable(
865 table_name, config.num_embeddings, config.embedding_dim);
866 if (rc != 0) {
867 return rc;
868 }
869 }
870 }
871
872 std::lock_guard<std::mutex> guard(state_mu_);
873 const auto [it, inserted] = tables_.emplace(table_name, TableState{config});
874 if (!inserted) {
875 if (it->second.config.embedding_dim != config.embedding_dim ||
876 it->second.config.num_embeddings != config.num_embeddings) {
877 return -1;
878 }
879 }
880 return 0;
881 }
882
883 int RDMAPSClientAdapter::AsyncGetParameter(const base::ConstArray<uint64_t>&,
884 float*) {
885 throw std::runtime_error(
886 "RDMA adapter AsyncGetParameter not implemented yet");
887 }
888
889 void RDMAPSClientAdapter::Command(PSCommand) {
890 EnsureThreadInitialized();
891 if (num_shards_ <= 1) {
892 if (client_ == nullptr) {
893 throw std::runtime_error("RDMA adapter has no initialized client");
894 }
895 client_->Barrier("rdma_command", 0);
896 return;
897 }
898 if (shard_clients_.empty()) {
899 throw std::runtime_error("RDMA adapter has no initialized clients");
900 }
901 shard_clients_.front()->Barrier("rdma_command", 0);
902 }
903
904 uint64_t
905 RDMAPSClientAdapter::PrefetchParameter(const base::ConstArray<uint64_t>& keys) {
906 EnsureThreadInitialized();
907 if (keys.Size() == 0) {
908 throw std::invalid_argument("RDMA prefetch requires at least one key");
909 }
910
911 const int64_t embedding_dim = DefaultEmbeddingDimOrThrow();
912 const std::size_t response_bytes =
913 petps::FixedSlotResponseBytes(keys.Size(), FLAGS_value_size);
914 std::size_t buffer_id = 0;
915 const bool borrow_single_shard_response =
916 num_shards_ <= 1 && keys.Size() <= MaxGetKeysPerRpc();
917 const bool batch_response = !borrow_single_shard_response;
918 float* buffer = AcquirePrefetchBuffer(response_bytes, &buffer_id);
919 auto* status_word = petps::FixedSlotStatusWord(
920 buffer, static_cast<std::size_t>(keys.Size()), FLAGS_value_size);
921 *status_word = static_cast<std::int32_t>(petps::RpcStatus::kPending);
922
923 int rpc_id = -1;
924 try {
925 rpc_id = SubmitGetParameter(keys, buffer, true, 0);
926 } catch (...) {
927 ReleasePrefetchBuffer(buffer_id);
928 throw;
929 }
930
931 std::lock_guard<std::mutex> guard(state_mu_);
932 const uint64_t prefetch_id = next_prefetch_id_++;
933 prefetches_.emplace(
934 prefetch_id,
935 PrefetchState{
936 buffer,
937 buffer_id,
938 rpc_id,
939 static_cast<int64_t>(keys.Size()),
940 embedding_dim,
941 borrow_single_shard_response,
942 batch_response,
943 });
944 return prefetch_id;
945 }
946
947 bool RDMAPSClientAdapter::IsPrefetchDone(uint64_t prefetch_id) {
948 EnsureThreadInitialized();
949 const PrefetchState state = GetPrefetchState(prefetch_id);
950 return QueryRPCFinished(state.rpc_id);
951 }
952
953 void RDMAPSClientAdapter::WaitForPrefetch(uint64_t prefetch_id) {
954 EnsureThreadInitialized();
955 const PrefetchState state = GetPrefetchState(prefetch_id);
956 try {
957 WaitRPCFinish(state.rpc_id);
958 } catch (...) {
959 RevokeRPCResource(state.rpc_id);
960 MarkPrefetchConsumed(prefetch_id);
961 throw;
962 }
963 }
964
965 bool RDMAPSClientAdapter::GetPrefetchResult(
966 uint64_t prefetch_id, std::vector<std::vector<float>>* values) {
967 if (values == nullptr) {
968 return false;
969 }
970
971 const PrefetchState state = GetPrefetchState(prefetch_id);
972 std::vector<float> flat;
973 int64_t num_rows = 0;
974 if (!GetPrefetchResultFlat(
975 prefetch_id, &flat, &num_rows, state.embedding_dim)) {
976 return false;
977 }
978
979 petps::CopyFlatRowsToVectors(
980 flat.data(),
981 static_cast<std::size_t>(num_rows),
982 static_cast<std::size_t>(state.embedding_dim),
983 values);
984 return true;
985 }
986
// Waits for prefetch `prefetch_id` and copies its rows into *values as
// key_count * embedding_dim contiguous floats. *num_rows is set to key_count.
// Returns false when an output pointer is null or embedding_dim differs from
// the prefetch's, without consuming the prefetch in those cases. Also returns
// false when the response status is not kOk or the payload is shorter than
// expected; those cases release the RPC and the prefetch. On success the RPC
// and the prefetch entry are released as well, so each prefetch id can be
// consumed at most once.
// The result is read from the client's response slot when
// BorrowPrefetchResult succeeds, and from the pooled prefetch buffer
// otherwise. FLAGS_rdma_adapter_skip_prefetch_result_copy leaves *values
// empty (benchmark mode).
// With RECSTORE_RDMA_ADAPTER_PROFILE enabled, average wait/copy/revoke times
// are printed to stdout on the 1st call and every 512th call.
bool RDMAPSClientAdapter::GetPrefetchResultFlat(
    uint64_t prefetch_id,
    std::vector<float>* values,
    int64_t* num_rows,
    int64_t embedding_dim) {
  if (values == nullptr || num_rows == nullptr) {
    return false;
  }

  const PrefetchState state = GetPrefetchState(prefetch_id);
  if (embedding_dim != state.embedding_dim) {
    return false;
  }

  const bool profile_enabled = AdapterProfileEnabled();
  const auto wait_begin = std::chrono::steady_clock::now();
  std::int32_t status_code = static_cast<std::int32_t>(petps::RpcStatus::kOk);
  std::size_t response_bytes = 0;
  // NOTE(review): the borrowed path performs no explicit wait here; this
  // presumably relies on BorrowGetResultPayload blocking until the response
  // arrives. Confirm against PetPSClient.
  const float* result_payload =
      BorrowPrefetchResult(state, &status_code, &response_bytes);
  if (result_payload == nullptr) {
    // Fallback: wait for the RPC/batch, then read the pooled buffer, whose
    // status word follows key_count rows.
    WaitForPrefetch(prefetch_id);
    const auto* status_word = petps::FixedSlotStatusWord(
        state.buffer,
        static_cast<std::size_t>(state.key_count),
        FLAGS_value_size);
    status_code = *status_word;
    response_bytes = static_cast<std::size_t>(state.key_count) *
                     static_cast<std::size_t>(FLAGS_value_size);
    result_payload = state.buffer;
  }
  const auto wait_end = std::chrono::steady_clock::now();
  if (status_code != static_cast<std::int32_t>(petps::RpcStatus::kOk)) {
    RevokeRPCResource(state.rpc_id);
    MarkPrefetchConsumed(prefetch_id);
    return false;
  }

  const std::size_t value_count =
      static_cast<std::size_t>(state.key_count) *
      static_cast<std::size_t>(state.embedding_dim);
  const auto assign_begin = std::chrono::steady_clock::now();
  if (FLAGS_rdma_adapter_skip_prefetch_result_copy) {
    values->clear();
  } else if (response_bytes == 0) {
    values->clear();
  } else {
    values->resize(value_count);
    if (value_count > 0) {
      // Refuse to read past the payload if the response is shorter than
      // key_count * embedding_dim floats.
      const std::size_t expected_bytes = value_count * sizeof(values->front());
      if (response_bytes < expected_bytes) {
        RevokeRPCResource(state.rpc_id);
        MarkPrefetchConsumed(prefetch_id);
        return false;
      }
      std::memcpy(values->data(), result_payload, expected_bytes);
    }
  }
  const auto assign_end = std::chrono::steady_clock::now();
  *num_rows = state.key_count;
  const auto revoke_begin = std::chrono::steady_clock::now();
  RevokeRPCResource(state.rpc_id);
  MarkPrefetchConsumed(prefetch_id);
  const auto revoke_end = std::chrono::steady_clock::now();
  if (profile_enabled) {
    // Process-wide running totals shared by all adapter instances and threads.
    static std::atomic<std::uint64_t> count{0};
    static std::atomic<std::uint64_t> wait_ns{0};
    static std::atomic<std::uint64_t> assign_ns{0};
    static std::atomic<std::uint64_t> revoke_ns{0};
    const std::uint64_t current = count.fetch_add(1) + 1;
    wait_ns.fetch_add(
        static_cast<std::uint64_t>(NsSince(wait_begin, wait_end)));
    assign_ns.fetch_add(
        static_cast<std::uint64_t>(NsSince(assign_begin, assign_end)));
    revoke_ns.fetch_add(
        static_cast<std::uint64_t>(NsSince(revoke_begin, revoke_end)));
    if (current == 1 || current % 512 == 0) {
      const double denom = static_cast<double>(current);
      std::cout
          << "component=rdma_adapter_prefetch_profile"
          << " batches=" << current
          << " wait_avg_ns=" << static_cast<double>(wait_ns.load()) / denom
          << " assign_avg_ns=" << static_cast<double>(assign_ns.load()) / denom
          << " revoke_avg_ns=" << static_cast<double>(revoke_ns.load()) / denom
          << " value_count=" << value_count << std::endl;
    }
  }
  return true;
}
1076
1077 } // namespace recstore
1078