GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 0.0% 0 / 0 / 500
Functions: 0.0% 0 / 0 / 34
Branches: 0.0% 0 / 0 / 602

ps/rdma/petps_client.cc
Line Branch Exec Source
#include "ps/rdma/petps_client.h"

#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <thread>

#include <folly/portability/GFlags.h>

#include "ps/rdma/control_plane.h"
#include "ps/rdma/rdma_common.h"
#include "ps/rdma/rc_options.h"
15
16 DECLARE_int32(global_id);
17 DECLARE_int32(num_server_processes);
18 DECLARE_int32(num_client_processes);
19 DECLARE_int32(value_size);
20 DECLARE_int32(max_kv_num_per_request);
21 DEFINE_string(rdma_get_response_mode,
22 "direct_sg",
23 "RDMA GET response mode: direct_sg or staging_copy");
24
25 namespace petps {
26 namespace {
27
28 using petps::Exchange;
29 using petps::NamespaceToken;
30 using petps::NowNs;
31
32 std::size_t ComputeMaxGetKeysPerRpc() {
33 return GetKeysPerRpcByResponseBudget(
34 static_cast<std::size_t>(FLAGS_value_size),
35 static_cast<std::size_t>(FLAGS_rdma_rc_mtu_bytes),
36 static_cast<std::size_t>(FLAGS_rdma_rc_target_response_mtu));
37 }
38
39 std::int32_t WaitStatus(const StatusWord* status, std::uint64_t seq) {
40 const auto start = std::chrono::steady_clock::now();
41 int spin_iterations = 0;
42 while (!StatusWordDone(*status, seq)) {
43 if (spin_iterations < FLAGS_rdma_rc_wait_spin_iterations) {
44 ++spin_iterations;
45 } else {
46 spin_iterations = 0;
47 std::this_thread::yield();
48 }
49 if (FLAGS_rdma_wait_timeout_ms > 0) {
50 const auto elapsed_ms =
51 std::chrono::duration_cast<std::chrono::milliseconds>(
52 std::chrono::steady_clock::now() - start)
53 .count();
54 if (elapsed_ms > FLAGS_rdma_wait_timeout_ms) {
55 throw std::runtime_error("RC write RPC wait timeout");
56 }
57 }
58 }
59 return status->status;
60 }
61
62 void FillBaseDescriptor(
63 RequestDescriptor* descriptor,
64 std::uint64_t seq,
65 std::size_t key_count,
66 const RcClientQpView& view,
67 std::uint32_t shard_id,
68 std::uint32_t client_id) {
69 *descriptor = RequestDescriptor{};
70 descriptor->seq = seq;
71 descriptor->shard_id = shard_id;
72 descriptor->client_id = client_id;
73 descriptor->qp_index = static_cast<std::uint32_t>(view.qp_index);
74 descriptor->key_count = static_cast<std::uint32_t>(key_count);
75 descriptor->value_size = static_cast<std::uint32_t>(FLAGS_value_size);
76 descriptor->embedding_dim =
77 static_cast<std::uint32_t>(FLAGS_value_size / sizeof(float));
78 descriptor->payload_offset =
79 static_cast<std::uint32_t>(Align64(sizeof(RequestDescriptor)));
80 descriptor->client_response_addr =
81 reinterpret_cast<std::uint64_t>(view.response_payload);
82 descriptor->client_status_addr = reinterpret_cast<std::uint64_t>(view.status);
83 }
84
85 } // namespace
86
87 PetPSClient::PetPSClient(const std::string& host, int port, int shard)
88 : PetPSClient(host, port, shard, -1) {}
89
90 PetPSClient::PetPSClient(
91 const std::string& host, int port, int shard, int logical_client_id)
92 : BaseParameterClient(host, port, shard),
93 namespace_token_(NamespaceToken()),
94 explicit_client_id_(logical_client_id) {}
95
96 PetPSClient::~PetPSClient() = default;
97
98 void PetPSClient::Barrier(const std::string&, int) {}
99
100 void PetPSClient::InitializeTransport() {
101 if (transport_ != nullptr) {
102 return;
103 }
104 client_id_ =
105 explicit_client_id_ >= 0
106 ? explicit_client_id_
107 : (FLAGS_rdma_rc_client_id_base >= 0
108 ? FLAGS_rdma_rc_client_id_base
109 : FLAGS_global_id - FLAGS_num_server_processes);
110 if (client_id_ < 0) {
111 throw std::runtime_error("invalid RC write logical client_id");
112 }
113 const int logical_num_clients =
114 FLAGS_rdma_rc_num_logical_clients >= 0
115 ? FLAGS_rdma_rc_num_logical_clients
116 : FLAGS_num_client_processes;
117 if (client_id_ >= logical_num_clients) {
118 throw std::runtime_error("RC write logical client_id out of range");
119 }
120 config_.shard_id = shard_;
121 config_.client_id = client_id_;
122 config_.num_clients = logical_num_clients;
123 config_.qps_per_client_per_shard = FLAGS_rdma_rc_qps_per_client_per_shard;
124 config_.slots_per_qp = FLAGS_rdma_rc_slots_per_qp;
125 config_.request_slot_bytes =
126 static_cast<std::size_t>(FLAGS_rdma_rc_request_slot_bytes);
127 config_.response_slot_bytes =
128 static_cast<std::size_t>(FLAGS_rdma_rc_response_slot_bytes);
129 config_.control_plane_host = FLAGS_rdma_control_plane_host;
130 config_.control_plane_port = FLAGS_rdma_control_plane_port;
131 config_.control_plane_timeout_ms = FLAGS_rdma_control_plane_timeout_ms;
132 config_.namespace_token = namespace_token_;
133
134 transport_ = std::make_unique<RcShardClientTransport>(config_);
135 RdmaControlPlaneClient control_plane({
136 config_.control_plane_host,
137 config_.control_plane_port,
138 config_.control_plane_timeout_ms,
139 });
140 control_plane.WaitServer(shard_, config_.control_plane_timeout_ms);
141 qps_.clear();
142 qps_.reserve(static_cast<std::size_t>(config_.qps_per_client_per_shard));
143 for (int qp = 0; qp < config_.qps_per_client_per_shard; ++qp) {
144 QpContext context;
145 context.qp_index = qp;
146 context.slots.reserve(static_cast<std::size_t>(config_.slots_per_qp));
147 for (int slot_in_qp = 0; slot_in_qp < config_.slots_per_qp; ++slot_in_qp) {
148 context.slots.push_back(
149 SlotContext{transport_->OpenSlot(qp, slot_in_qp), 1, false});
150 }
151 qps_.push_back(std::move(context));
152 }
153 }
154
155 void PetPSClient::InitThread() {
156 std::lock_guard<std::mutex> guard(mu_);
157 InitializeTransport();
158 thread_initialized_ = true;
159 }
160
161 std::size_t PetPSClient::ResponseBufferBytes(std::size_t key_count) const {
162 return GetResponseBytes(
163 key_count, static_cast<std::size_t>(FLAGS_value_size)) +
164 sizeof(std::int32_t);
165 }
166
167 void* PetPSClient::GetReceiveBuffer(size_t size) {
168 std::lock_guard<std::mutex> guard(mu_);
169 receive_buffers_.emplace_back(size, 0);
170 return receive_buffers_.back().data();
171 }
172
173 const float* PetPSClient::BorrowGetResultPayload(
174 int rpc_id,
175 std::size_t* key_count,
176 std::size_t* response_bytes,
177 std::int32_t* status_code) {
178 PendingRpc pending;
179 {
180 std::lock_guard<std::mutex> guard(mu_);
181 if (!PendingRpcLocked(rpc_id, &pending)) {
182 return nullptr;
183 }
184 }
185
186 auto& slot = SlotAt(pending.qp_index, pending.slot_in_qp);
187 const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0;
188 const std::uint64_t wait_start_ns = profile_enabled ? NowNs() : 0;
189 const std::int32_t rc_status = WaitStatus(slot.view.status, pending.seq);
190 if (profile_enabled) {
191 profile_.wait_rpc_count.fetch_add(1, std::memory_order_relaxed);
192 profile_.wait_status_ns.fetch_add(
193 NowNs() - wait_start_ns, std::memory_order_relaxed);
194 }
195
196 const std::size_t actual_response_bytes = std::min<std::size_t>(
197 slot.view.status->response_bytes, pending.response_bytes);
198 if (key_count != nullptr) {
199 *key_count = pending.key_count;
200 }
201 if (response_bytes != nullptr) {
202 *response_bytes = actual_response_bytes;
203 }
204 if (status_code != nullptr) {
205 *status_code = rc_status;
206 }
207 if (pending.recv_buffer != nullptr) {
208 auto* user_status = FixedSlotStatusWord(
209 pending.recv_buffer, pending.key_count, FLAGS_value_size);
210 *user_status = rc_status;
211 }
212 MaybeReportProfile();
213 return reinterpret_cast<const float*>(slot.view.response_payload);
214 }
215
216 PetPSClient::SlotHandle PetPSClient::AcquireIdleSlot() {
217 if (FLAGS_rdma_rc_profile_interval_ms > 0) {
218 profile_.acquire_qp_count.fetch_add(1, std::memory_order_relaxed);
219 }
220 for (std::size_t qp_index = 0; qp_index < qps_.size(); ++qp_index) {
221 auto& qp = qps_[qp_index];
222 for (std::size_t slot_in_qp = 0; slot_in_qp < qp.slots.size();
223 ++slot_in_qp) {
224 if (!qp.slots[slot_in_qp].busy) {
225 qp.slots[slot_in_qp].busy = true;
226 return SlotHandle{
227 static_cast<int>(qp_index),
228 static_cast<int>(slot_in_qp),
229 };
230 }
231 }
232 }
233 if (FLAGS_rdma_rc_profile_interval_ms > 0) {
234 profile_.acquire_qp_failures.fetch_add(1, std::memory_order_relaxed);
235 }
236 throw std::runtime_error("no idle RC write slot available");
237 }
238
239 PetPSClient::SlotContext& PetPSClient::SlotAt(int qp_index, int slot_in_qp) {
240 auto& qp = qps_.at(static_cast<std::size_t>(qp_index));
241 return qp.slots.at(static_cast<std::size_t>(slot_in_qp));
242 }
243
244 const PetPSClient::SlotContext&
245 PetPSClient::SlotAt(int qp_index, int slot_in_qp) const {
246 const auto& qp = qps_.at(static_cast<std::size_t>(qp_index));
247 return qp.slots.at(static_cast<std::size_t>(slot_in_qp));
248 }
249
250 void PetPSClient::EnsureThreadInitializedLocked() const {
251 if (!thread_initialized_) {
252 throw std::runtime_error("PetPSClient::InitThread must be called first");
253 }
254 }
255
256 bool PetPSClient::PendingRpcLocked(int rpc_id, PendingRpc* pending) const {
257 const auto it = pending_rpcs_.find(rpc_id);
258 if (it == pending_rpcs_.end()) {
259 return false;
260 }
261 if (pending != nullptr) {
262 *pending = it->second;
263 }
264 return true;
265 }
266
267 bool PetPSClient::RequestPayloadFitsSlot(std::size_t payload_bytes) const {
268 return Align64(sizeof(RequestDescriptor)) + payload_bytes +
269 Align64(sizeof(CommitWord)) <=
270 config_.request_slot_bytes;
271 }
272
273 float* PetPSClient::AllocateStatusReceiveBufferLocked() {
274 receive_buffers_.emplace_back(sizeof(std::int32_t), 0);
275 return reinterpret_cast<float*>(receive_buffers_.back().data());
276 }
277
278 void PetPSClient::MaybeReportProfile() {
279 if (FLAGS_rdma_rc_profile_interval_ms <= 0) {
280 return;
281 }
282 const std::uint64_t now = NowNs();
283 const std::uint64_t interval =
284 static_cast<std::uint64_t>(FLAGS_rdma_rc_profile_interval_ms) * 1000000;
285 std::uint64_t expected =
286 profile_.next_report_ns.load(std::memory_order_relaxed);
287 if (expected == 0) {
288 profile_.next_report_ns.compare_exchange_strong(
289 expected, now + interval, std::memory_order_relaxed);
290 return;
291 }
292 if (now < expected ||
293 !profile_.next_report_ns.compare_exchange_strong(
294 expected, now + interval, std::memory_order_relaxed)) {
295 return;
296 }
297
298 const std::uint64_t submit_count = Exchange(&profile_.submit_rpc_count);
299 const std::uint64_t wait_count = Exchange(&profile_.wait_rpc_count);
300 const std::uint64_t revoke_count = Exchange(&profile_.revoke_rpc_count);
301 const std::uint64_t submit_ns = Exchange(&profile_.submit_request_ns);
302 const std::uint64_t wait_ns = Exchange(&profile_.wait_status_ns);
303 const std::uint64_t copy_ns = Exchange(&profile_.copy_response_ns);
304 const std::uint64_t revoke_ns = Exchange(&profile_.revoke_resource_ns);
305 const std::uint64_t pending_samples = Exchange(&profile_.pending_rpc_samples);
306 const std::uint64_t pending_sum = Exchange(&profile_.pending_rpc_sum);
307 std::cout
308 << "component=rdma_rc_client_profile"
309 << " shard=" << shard_ << " client_id=" << client_id_
310 << " submit_count=" << submit_count << " wait_count=" << wait_count
311 << " revoke_count=" << revoke_count
312 << " acquire_qp_count=" << Exchange(&profile_.acquire_qp_count)
313 << " acquire_qp_failures=" << Exchange(&profile_.acquire_qp_failures)
314 << " submit_avg_ns=" << (submit_count == 0 ? 0 : submit_ns / submit_count)
315 << " wait_status_avg_ns=" << (wait_count == 0 ? 0 : wait_ns / wait_count)
316 << " copy_response_avg_ns="
317 << (wait_count == 0 ? 0 : copy_ns / wait_count)
318 << " copied_bytes=" << Exchange(&profile_.response_bytes_copied)
319 << " revoke_avg_ns=" << (revoke_count == 0 ? 0 : revoke_ns / revoke_count)
320 << " pending_rpc_peak=" << Exchange(&profile_.pending_rpc_peak)
321 << " pending_rpc_avg="
322 << (pending_samples == 0 ? 0 : pending_sum / pending_samples)
323 << " pending_rpc_last="
324 << profile_.pending_rpc_last.load(std::memory_order_relaxed) << std::endl;
325 }
326
327 void PetPSClient::FillGetDescriptor(
328 RequestDescriptor* descriptor,
329 std::uint64_t seq,
330 std::size_t key_count,
331 std::size_t response_bytes,
332 const RcClientQpView& view) const {
333 FillBaseDescriptor(
334 descriptor,
335 seq,
336 key_count,
337 view,
338 static_cast<std::uint32_t>(shard_),
339 static_cast<std::uint32_t>(client_id_));
340 descriptor->op = static_cast<std::uint16_t>(RcOp::kGet);
341 descriptor->payload_bytes =
342 static_cast<std::uint32_t>(GetRequestBytes(key_count));
343 descriptor->response_bytes = static_cast<std::uint32_t>(response_bytes);
344 if (FLAGS_rdma_get_response_mode == "direct_sg") {
345 descriptor->flags |= kRcFlagGetDirectSg | kRcFlagGetAllowFallbackCopy;
346 } else if (FLAGS_rdma_get_response_mode != "staging_copy") {
347 LOG(FATAL) << "unsupported --rdma_get_response_mode="
348 << FLAGS_rdma_get_response_mode;
349 }
350 }
351
352 void PetPSClient::FillPutDescriptor(
353 RequestDescriptor* descriptor,
354 std::uint64_t seq,
355 std::size_t key_count,
356 std::size_t payload_bytes,
357 const RcClientQpView& view) const {
358 FillBaseDescriptor(
359 descriptor,
360 seq,
361 key_count,
362 view,
363 static_cast<std::uint32_t>(shard_),
364 static_cast<std::uint32_t>(client_id_));
365 descriptor->op = static_cast<std::uint16_t>(RcOp::kPut);
366 descriptor->payload_bytes = static_cast<std::uint32_t>(payload_bytes);
367 descriptor->response_bytes = 0;
368 }
369
370 void PetPSClient::FillUpdateDescriptor(
371 RequestDescriptor* descriptor,
372 std::uint64_t seq,
373 std::size_t key_count,
374 std::size_t payload_bytes,
375 const std::string& table_name,
376 const RcClientQpView& view) const {
377 FillPutDescriptor(descriptor, seq, key_count, payload_bytes, view);
378 descriptor->op = static_cast<std::uint16_t>(RcOp::kUpdate);
379 if (!CopyTableName(table_name, &descriptor->table_name)) {
380 throw std::runtime_error("UPDATE table name too long");
381 }
382 }
383
384 void PetPSClient::FillInitTableDescriptor(
385 RequestDescriptor* descriptor,
386 std::uint64_t seq,
387 const std::string& table_name,
388 const RcClientQpView& view) const {
389 FillPutDescriptor(
390 descriptor, seq, /*key_count=*/0, InitTablePayloadBytes(), view);
391 descriptor->op = static_cast<std::uint16_t>(RcOp::kInitTable);
392 if (!CopyTableName(table_name, &descriptor->table_name)) {
393 throw std::runtime_error("INIT table name too long");
394 }
395 }
396
397 int PetPSClient::SubmitRpcLocked(
398 SlotContext* slot,
399 const RequestDescriptor& descriptor,
400 const void* payload,
401 std::size_t payload_bytes,
402 float* recv_buffer,
403 std::size_t key_count,
404 std::size_t response_bytes,
405 bool is_async) {
406 if (slot == nullptr) {
407 throw std::runtime_error("slot context is null");
408 }
409 ResetStatusWord(slot->view.status, descriptor.seq);
410 const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0;
411 const std::uint64_t submit_start_ns = profile_enabled ? NowNs() : 0;
412 transport_->SubmitRequest(slot->view, descriptor, payload, payload_bytes);
413 if (profile_enabled) {
414 profile_.submit_rpc_count.fetch_add(1, std::memory_order_relaxed);
415 profile_.submit_request_ns.fetch_add(
416 NowNs() - submit_start_ns, std::memory_order_relaxed);
417 }
418 VLOG(1) << "component=rdma_rc_client event=submit shard=" << shard_
419 << " client_id=" << client_id_ << " qp=" << slot->view.qp_index
420 << " slot=" << slot->view.slot_index << " seq=" << descriptor.seq
421 << " op=" << descriptor.op << " key_count=" << key_count
422 << " payload_bytes=" << payload_bytes
423 << " response_bytes=" << response_bytes;
424
425 const int rpc_id = next_rpc_id_.fetch_add(1);
426 pending_rpcs_.emplace(
427 rpc_id,
428 PendingRpc{
429 slot->view.qp_index,
430 slot->view.slot_in_qp,
431 slot->view.slot_index,
432 descriptor.seq,
433 recv_buffer,
434 key_count,
435 response_bytes,
436 });
437 if (profile_enabled) {
438 const std::uint64_t pending_size = pending_rpcs_.size();
439 profile_.pending_rpc_samples.fetch_add(1, std::memory_order_relaxed);
440 profile_.pending_rpc_sum.fetch_add(pending_size, std::memory_order_relaxed);
441 profile_.pending_rpc_last.store(pending_size, std::memory_order_relaxed);
442 std::uint64_t peak =
443 profile_.pending_rpc_peak.load(std::memory_order_relaxed);
444 while (pending_size > peak &&
445 !profile_.pending_rpc_peak.compare_exchange_weak(
446 peak, pending_size, std::memory_order_relaxed)) {
447 }
448 MaybeReportProfile();
449 }
450 if (!is_async) {
451 WaitRPCFinish(rpc_id);
452 }
453 return rpc_id;
454 }
455
456 int PetPSClient::GetParameter(base::ConstArray<uint64_t> keys,
457 std::vector<std::vector<float>>* values) {
458 values->clear();
459 if (keys.Size() == 0) {
460 return 0;
461 }
462 const int embedding_dim = FLAGS_value_size / sizeof(float);
463 std::vector<float> flat(keys.Size() * embedding_dim + 1, 0.0f);
464 const int rpc_id = GetParameter(keys, flat.data(), false, 0);
465 const auto* status =
466 FixedSlotStatusWord(flat.data(), keys.Size(), FLAGS_value_size);
467 if (*status != static_cast<std::int32_t>(RpcStatus::kOk)) {
468 RevokeRPCResource(rpc_id);
469 return -1;
470 }
471 CopyFlatRowsToVectors(
472 flat.data(),
473 keys.Size(),
474 static_cast<std::size_t>(embedding_dim),
475 values);
476 RevokeRPCResource(rpc_id);
477 return 0;
478 }
479
480 int PetPSClient::GetParameter(
481 base::ConstArray<uint64_t> keys, float* values, bool isAsync, int) {
482 if (keys.Size() == 0) {
483 auto* status =
484 reinterpret_cast<std::int32_t*>(reinterpret_cast<char*>(values));
485 *status = static_cast<std::int32_t>(RpcStatus::kOk);
486 return 0;
487 }
488 int rpc_id = 0;
489 {
490 std::lock_guard<std::mutex> guard(mu_);
491 EnsureThreadInitializedLocked();
492 if (keys.Size() > ComputeMaxGetKeysPerRpc()) {
493 throw std::runtime_error(
494 "single-shard GET batch exceeds RC response budget");
495 }
496
497 const SlotHandle slot_handle = AcquireIdleSlot();
498 auto& slot = SlotAt(slot_handle.qp_index, slot_handle.slot_in_qp);
499 RequestDescriptor descriptor;
500 const std::size_t response_bytes = GetResponseBytes(
501 keys.Size(), static_cast<std::size_t>(FLAGS_value_size));
502 FillGetDescriptor(
503 &descriptor, slot.next_seq++, keys.Size(), response_bytes, slot.view);
504 if (descriptor.payload_bytes >
505 PutPayloadBudget(config_.request_slot_bytes)) {
506 slot.busy = false;
507 throw std::runtime_error("GET request exceeds RC request slot");
508 }
509 rpc_id = SubmitRpcLocked(
510 &slot,
511 descriptor,
512 keys.Data(),
513 descriptor.payload_bytes,
514 values,
515 keys.Size(),
516 response_bytes,
517 true);
518 }
519 if (!isAsync) {
520 WaitRPCFinish(rpc_id);
521 }
522 return rpc_id;
523 }
524
525 bool PetPSClient::QueryRPCFinished(int rpc_id) {
526 std::lock_guard<std::mutex> guard(mu_);
527 PendingRpc pending;
528 if (!PendingRpcLocked(rpc_id, &pending)) {
529 return true;
530 }
531 const auto& slot = SlotAt(pending.qp_index, pending.slot_in_qp);
532 return StatusWordDone(*slot.view.status, pending.seq);
533 }
534
535 void PetPSClient::WaitRPCFinish(int rpc_id) {
536 PendingRpc pending;
537 {
538 std::lock_guard<std::mutex> guard(mu_);
539 if (!PendingRpcLocked(rpc_id, &pending)) {
540 return;
541 }
542 }
543
544 auto& slot = SlotAt(pending.qp_index, pending.slot_in_qp);
545 const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0;
546 const std::uint64_t wait_start_ns = profile_enabled ? NowNs() : 0;
547 const std::int32_t status_code = WaitStatus(slot.view.status, pending.seq);
548 if (profile_enabled) {
549 profile_.wait_rpc_count.fetch_add(1, std::memory_order_relaxed);
550 profile_.wait_status_ns.fetch_add(
551 NowNs() - wait_start_ns, std::memory_order_relaxed);
552 }
553 VLOG(1) << "component=rdma_rc_client event=done shard=" << shard_
554 << " client_id=" << client_id_ << " qp=" << pending.qp_index
555 << " slot=" << pending.slot_index << " seq=" << pending.seq
556 << " status=" << status_code
557 << " response_bytes=" << pending.response_bytes;
558 const std::size_t actual_response_bytes = std::min<std::size_t>(
559 slot.view.status->response_bytes, pending.response_bytes);
560 if (actual_response_bytes > 0 && !FLAGS_rdma_rc_skip_client_copy) {
561 const std::uint64_t copy_start_ns = profile_enabled ? NowNs() : 0;
562 std::memcpy(
563 pending.recv_buffer, slot.view.response_payload, actual_response_bytes);
564 if (profile_enabled) {
565 profile_.copy_response_ns.fetch_add(
566 NowNs() - copy_start_ns, std::memory_order_relaxed);
567 profile_.response_bytes_copied.fetch_add(
568 actual_response_bytes, std::memory_order_relaxed);
569 }
570 }
571 auto* user_status = FixedSlotStatusWord(
572 pending.recv_buffer, pending.key_count, FLAGS_value_size);
573 *user_status = status_code;
574 MaybeReportProfile();
575 }
576
577 void PetPSClient::RevokeRPCResource(int rpc_id) {
578 std::lock_guard<std::mutex> guard(mu_);
579 const auto it = pending_rpcs_.find(rpc_id);
580 if (it == pending_rpcs_.end()) {
581 return;
582 }
583 const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0;
584 const std::uint64_t revoke_start_ns = profile_enabled ? NowNs() : 0;
585 auto& slot = SlotAt(it->second.qp_index, it->second.slot_in_qp);
586 transport_->ClearRequestSlot(slot.view);
587 slot.busy = false;
588 pending_rpcs_.erase(it);
589 if (profile_enabled) {
590 const std::uint64_t pending_size = pending_rpcs_.size();
591 profile_.pending_rpc_samples.fetch_add(1, std::memory_order_relaxed);
592 profile_.pending_rpc_sum.fetch_add(pending_size, std::memory_order_relaxed);
593 profile_.pending_rpc_last.store(pending_size, std::memory_order_relaxed);
594 profile_.revoke_rpc_count.fetch_add(1, std::memory_order_relaxed);
595 profile_.revoke_resource_ns.fetch_add(
596 NowNs() - revoke_start_ns, std::memory_order_relaxed);
597 MaybeReportProfile();
598 }
599 }
600
601 int PetPSClient::PutParameter(const std::vector<uint64_t>& keys,
602 const std::vector<std::vector<float>>& values) {
603 if (keys.size() != values.size()) {
604 return -1;
605 }
606 if (keys.empty()) {
607 return 0;
608 }
609
610 std::size_t begin = 0;
611 while (begin < keys.size()) {
612 std::size_t end =
613 std::min(begin + static_cast<std::size_t>(FLAGS_max_kv_num_per_request),
614 keys.size());
615 std::vector<std::uint64_t> key_slice(
616 keys.begin() + begin, keys.begin() + end);
617 std::vector<std::vector<float>> value_slice(
618 values.begin() + begin, values.begin() + end);
619
620 std::string payload;
621 std::string error;
622 const std::size_t payload_bytes =
623 PutPayloadBytes(key_slice, value_slice, &payload, &error);
624 if (payload_bytes == 0 && !key_slice.empty()) {
625 throw std::runtime_error("RC PUT payload build failed: " + error);
626 }
627
628 float* recv = nullptr;
629 int rpc_id = 0;
630 {
631 std::lock_guard<std::mutex> guard(mu_);
632 EnsureThreadInitializedLocked();
633 const SlotHandle slot_handle = AcquireIdleSlot();
634 auto& slot = SlotAt(slot_handle.qp_index, slot_handle.slot_in_qp);
635 RequestDescriptor descriptor;
636 FillPutDescriptor(
637 &descriptor,
638 slot.next_seq++,
639 key_slice.size(),
640 payload_bytes,
641 slot.view);
642 if (!RequestPayloadFitsSlot(payload_bytes)) {
643 slot.busy = false;
644 throw std::runtime_error("PUT request exceeds RC request slot");
645 }
646 recv = AllocateStatusReceiveBufferLocked();
647 rpc_id = SubmitRpcLocked(
648 &slot, descriptor, payload.data(), payload_bytes, recv, 0, 0, true);
649 }
650 WaitRPCFinish(rpc_id);
651 const auto* status = reinterpret_cast<const std::int32_t*>(recv);
652 RevokeRPCResource(rpc_id);
653 if (*status != static_cast<std::int32_t>(RpcStatus::kOk)) {
654 return -1;
655 }
656 begin = end;
657 }
658
659 return 0;
660 }
661
662 int PetPSClient::InitEmbeddingTable(const std::string& table_name,
663 std::uint64_t num_embeddings,
664 std::uint64_t embedding_dim) {
665 const std::array<std::uint64_t, 2> payload_words = {
666 num_embeddings,
667 embedding_dim,
668 };
669
670 float* recv = nullptr;
671 int rpc_id = 0;
672 {
673 std::lock_guard<std::mutex> guard(mu_);
674 EnsureThreadInitializedLocked();
675 const SlotHandle slot_handle = AcquireIdleSlot();
676 auto& slot = SlotAt(slot_handle.qp_index, slot_handle.slot_in_qp);
677 RequestDescriptor descriptor;
678 FillInitTableDescriptor(
679 &descriptor, slot.next_seq++, table_name, slot.view);
680 if (!RequestPayloadFitsSlot(descriptor.payload_bytes)) {
681 slot.busy = false;
682 throw std::runtime_error("INIT request exceeds RC request slot");
683 }
684 recv = AllocateStatusReceiveBufferLocked();
685 rpc_id = SubmitRpcLocked(
686 &slot,
687 descriptor,
688 payload_words.data(),
689 descriptor.payload_bytes,
690 recv,
691 0,
692 0,
693 true);
694 }
695
696 WaitRPCFinish(rpc_id);
697 const auto* status = reinterpret_cast<const std::int32_t*>(recv);
698 RevokeRPCResource(rpc_id);
699 return (*status == static_cast<std::int32_t>(RpcStatus::kOk)) ? 0 : -1;
700 }
701
702 int PetPSClient::UpdateParameter(const std::string& table_name,
703 base::ConstArray<uint64_t> keys,
704 const std::vector<std::vector<float>>* grads) {
705 if (keys.Size() == 0) {
706 return 0;
707 }
708 if (grads == nullptr) {
709 return -1;
710 }
711 if (keys.Size() != grads->size()) {
712 return -1;
713 }
714
715 std::size_t begin = 0;
716 const std::size_t total_keys = static_cast<std::size_t>(keys.Size());
717 while (begin < total_keys) {
718 const std::size_t end =
719 std::min(begin + static_cast<std::size_t>(FLAGS_max_kv_num_per_request),
720 total_keys);
721 std::vector<std::uint64_t> key_slice(
722 keys.Data() + begin, keys.Data() + end);
723 std::vector<std::vector<float>> grad_slice(
724 grads->begin() + begin, grads->begin() + end);
725
726 std::string payload;
727 std::string error;
728 const std::size_t payload_bytes =
729 UpdatePayloadBytes(key_slice, grad_slice, &payload, &error);
730 if (payload_bytes == 0 && !key_slice.empty()) {
731 throw std::runtime_error("RC UPDATE payload build failed: " + error);
732 }
733
734 float* recv = nullptr;
735 int rpc_id = 0;
736 {
737 std::lock_guard<std::mutex> guard(mu_);
738 EnsureThreadInitializedLocked();
739
740 const SlotHandle slot_handle = AcquireIdleSlot();
741 auto& slot = SlotAt(slot_handle.qp_index, slot_handle.slot_in_qp);
742 RequestDescriptor descriptor;
743 FillUpdateDescriptor(
744 &descriptor,
745 slot.next_seq++,
746 key_slice.size(),
747 payload_bytes,
748 table_name,
749 slot.view);
750 if (!RequestPayloadFitsSlot(payload_bytes)) {
751 slot.busy = false;
752 throw std::runtime_error("UPDATE request exceeds RC request slot");
753 }
754 recv = AllocateStatusReceiveBufferLocked();
755 rpc_id = SubmitRpcLocked(
756 &slot, descriptor, payload.data(), payload_bytes, recv, 0, 0, true);
757 }
758
759 WaitRPCFinish(rpc_id);
760 const auto* status = reinterpret_cast<const std::int32_t*>(recv);
761 RevokeRPCResource(rpc_id);
762 if (*status != static_cast<std::int32_t>(RpcStatus::kOk)) {
763 return -1;
764 }
765 begin = end;
766 }
767
768 return 0;
769 }
770
771 int PetPSClient::FakePutParameter(base::ConstArray<uint64_t> keys,
772 float* values) {
773 const int embedding_dim = FLAGS_value_size / sizeof(float);
774 std::vector<std::vector<float>> rows;
775 rows.reserve(keys.Size());
776 for (int i = 0; i < keys.Size(); ++i) {
777 rows.emplace_back(
778 values + i * embedding_dim, values + (i + 1) * embedding_dim);
779 }
780 return PutParameter(keys.ToVector(), rows);
781 }
782
783 } // namespace petps
784