GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 0.0% 0 / 0 / 761
Functions: 0.0% 0 / 0 / 38
Branches: 0.0% 0 / 0 / 1276

ps/rdma/petps_server.cc
Line Branch Exec Source
1 #include <folly/init/Init.h>
2
3 #include <boost/coroutine2/all.hpp>
4
5 #include <atomic>
6 #include <array>
7 #include <algorithm>
8 #include <chrono>
9 #include <condition_variable>
10 #include <deque>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <cstring>
14 #include <fstream>
15 #include <iostream>
16 #include <limits>
17 #include <memory>
18 #include <mutex>
19 #include <stdexcept>
20 #include <string>
21 #include <thread>
22 #include <vector>
23
24 #include "base/bind_core.h"
25 #include "base/config.h"
26 #include "base/log.h"
27 #include "base/timer.h"
28 #include "memory/shm_file.h"
29 #include "ps/rdma/rdma_common.h"
30 #include "ps/base/cache_ps_impl.h"
31 #include "ps/rdma/control_plane.h"
32 #include "ps/rdma/rc_options.h"
33 #include "ps/rdma/rc_transport.h"
34 #include "ps/rdma/rdma_protocol.h"
35 #include "ps/rdma/rdma_status.h"
36
// Command-line flags owned by this translation unit (DEFINE_*) together with
// flags that are only declared here (DECLARE_*); the definitions of the
// declared flags are not part of the visible source.
37 DEFINE_string(config_path, "", "config file path");
38 DEFINE_int32(thread_num, 1, "RC write poll thread count");
// Process-topology flags, declared only.
39 DECLARE_int32(global_id);
40 DECLARE_int32(num_server_processes);
41 DECLARE_int32(num_client_processes);
42 DEFINE_int32(value_size, 128, "embedding row bytes");
43 DEFINE_int32(max_kv_num_per_request, 500, "max keys per request");
// Per its own help text, this flag is kept for compatibility and is unused.
44 DEFINE_bool(use_dram, false, "unused compatibility flag");
45 DEFINE_int32(numa_id, 0, "NUMA node id for mmap and core binding");
46
47 namespace {
48
// Names from petps made available unqualified inside this anonymous namespace.
49 using petps::Exchange;
50 using petps::NamespaceToken;
51 using petps::NowNs;
52
// Capacity of the scatter/gather array that HandleGetDirectSg fills for each
// WriteResponsePayloadSg call.
53 constexpr std::size_t kMaxDirectSgesPerWr = 32;
54
// Reports whether RDMA GET tracing is switched on through the
// RECSTORE_RDMA_GET_TRACE environment variable. The variable is consulted only
// on the first call; afterwards the cached answer is returned. Tracing is on
// when the variable is set to anything other than the exact string "0".
bool ShouldTraceRdmaGet() {
  static const bool trace_on = []() -> bool {
    const char* raw = std::getenv("RECSTORE_RDMA_GET_TRACE");
    if (raw == nullptr) {
      return false;
    }
    return std::strcmp(raw, "0") != 0;
  }();
  return trace_on;
}
62
// Returns the RDMA GET trace interval taken from the
// RECSTORE_RDMA_GET_TRACE_INTERVAL environment variable (base-10 integer).
//
// The variable is parsed once, on the first call, and the result is cached.
// The default of 5000 is used when the variable is unset, does not start with
// a number, parses to zero, or carries a minus sign.
std::uint64_t RdmaGetTraceInterval() {
  static const std::uint64_t interval = [] {
    constexpr std::uint64_t kDefaultInterval = 5000;
    const char* env = std::getenv("RECSTORE_RDMA_GET_TRACE_INTERVAL");
    if (env == nullptr) {
      return kDefaultInterval;
    }
    char* parse_end = nullptr;
    const auto parsed =
        static_cast<std::uint64_t>(std::strtoull(env, &parse_end, 10));
    // strtoull accepts a leading '-' and returns the negated value in unsigned
    // arithmetic, so "-7" would silently become a number close to 2^64. A
    // minus sign inside the consumed prefix therefore marks the input invalid.
    const char* const consumed_end = parse_end;
    const bool negative = std::find(env, consumed_end, '-') != consumed_end;
    if (negative || parsed == 0) {
      return kDefaultInterval;
    }
    return parsed;
  }();
  return interval;
}
75
// Formats the system clock's current time since its epoch, in whole
// microseconds, as a decimal string.
std::string TimestampNow() {
  using std::chrono::microseconds;
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto micros = std::chrono::duration_cast<microseconds>(since_epoch);
  return std::to_string(micros.count());
}
81
82 int ResolveShardId(const nlohmann::json& config) {
83 const int default_shard = FLAGS_global_id;
84 if (!config.contains("cache_ps") || !config["cache_ps"].is_object()) {
85 return default_shard;
86 }
87 const auto& cache_ps = config["cache_ps"];
88 if (cache_ps.contains("servers") && cache_ps["servers"].is_array()) {
89 for (const auto& server : cache_ps["servers"]) {
90 if (server.value("shard", -1) == FLAGS_global_id) {
91 return server.value("shard", default_shard);
92 }
93 }
94 }
95 return default_shard;
96 }
97
98 void NormalizeDramValuePath(nlohmann::json* base_kv_config) {
99 if (base_kv_config == nullptr || !base_kv_config->is_object()) {
100 return;
101 }
102 if (!base_kv_config->contains("value") ||
103 !(*base_kv_config)["value"].is_object()) {
104 return;
105 }
106 auto& value_cfg = (*base_kv_config)["value"];
107 const std::string value_type =
108 value_cfg.value("type", std::string("DRAM_VALUE_STORE"));
109 if (value_type != "DRAM_VALUE_STORE") {
110 return;
111 }
112 const std::string path = value_cfg.value("path", std::string());
113 if (path.empty() || path.rfind("/dev/shm", 0) == 0) {
114 return;
115 }
116 value_cfg["path"] = "/dev/shm/recstore_rdma_rc_" + TimestampNow() + "/value";
117 }
118
119 class PetPSServer {
120 public:
121 PetPSServer(CachePS* cache_ps,
122 int thread_count,
123 int shard_id,
124 const std::string& namespace_token)
125 : cache_ps_(cache_ps),
126 thread_count_(thread_count),
127 shard_id_(shard_id),
128 control_plane_client_(petps::RdmaControlPlaneEndpoint{
129 FLAGS_rdma_control_plane_host,
130 FLAGS_rdma_control_plane_port,
131 FLAGS_rdma_control_plane_timeout_ms,
132 }) {
133 petps::RcTransportConfig config;
134 config.shard_id = shard_id_;
135 config.num_clients =
136 FLAGS_rdma_rc_num_logical_clients >= 0
137 ? FLAGS_rdma_rc_num_logical_clients
138 : FLAGS_num_client_processes;
139 config.qps_per_client_per_shard = FLAGS_rdma_rc_qps_per_client_per_shard;
140 config.slots_per_qp = FLAGS_rdma_rc_slots_per_qp;
141 config.request_slot_bytes =
142 static_cast<std::size_t>(FLAGS_rdma_rc_request_slot_bytes);
143 config.response_slot_bytes =
144 static_cast<std::size_t>(FLAGS_rdma_rc_response_slot_bytes);
145 config.control_plane_host = FLAGS_rdma_control_plane_host;
146 config.control_plane_port = FLAGS_rdma_control_plane_port;
147 config.control_plane_timeout_ms = FLAGS_rdma_control_plane_timeout_ms;
148 config.namespace_token = namespace_token;
149 transport_ = std::make_unique<petps::RcShardServerTransport>(config);
150 const auto backing = cache_ps_->GetRDMABackingRegion();
151 if (backing.data != nullptr && backing.size > 0) {
152 transport_->RegisterLocalMemoryRegion(backing.data, backing.size);
153 LOG(INFO) << "component=rdma_rc_server event=value_region_registered"
154 << " bytes=" << backing.size;
155 } else {
156 LOG(INFO) << "component=rdma_rc_server event=value_region_unavailable";
157 }
158 last_seq_.assign(
159 static_cast<std::size_t>(transport_->TotalSlots()), std::uint64_t{0});
160 inflight_seq_.assign(
161 static_cast<std::size_t>(transport_->TotalSlots()), std::uint64_t{0});
162 get_payload_worker_count_ = FLAGS_rdma_rc_server_get_workers;
163 if (get_payload_worker_count_ < 0) {
164 LOG(FATAL) << "--rdma_rc_server_get_workers must be non-negative";
165 }
166 poller_profiles_.reserve(
167 static_cast<std::size_t>(std::max(1, thread_count_)));
168 for (int i = 0; i < std::max(1, thread_count_); ++i) {
169 poller_profiles_.emplace_back(std::make_unique<PollerProfile>());
170 }
171 get_payload_completions_.resize(
172 static_cast<std::size_t>(std::max(1, thread_count_)));
173 }
174
175 void Run() {
176 StartGetPayloadWorkers();
177 for (int i = 0; i < thread_count_; ++i) {
178 threads_.emplace_back(&PetPSServer::PollingThread, this, i);
179 }
180 }
181
182 private:
183 struct GetPayloadTask {
184 int slot = -1;
185 int client_id = -1;
186 int qp_index = -1;
187 int slot_in_qp = -1;
188 int poll_thread_id = -1;
189 std::uint64_t seq = 0;
190 petps::RequestDescriptor descriptor{};
191 const char* payload = nullptr;
192 petps::RcShardServerTransport::ResponseView response{};
193 };
194
195 struct GetPayloadCompletion {
196 int slot = -1;
197 int client_id = -1;
198 int qp_index = -1;
199 int slot_in_qp = -1;
200 int poll_thread_id = -1;
201 std::uint64_t seq = 0;
202 petps::RcShardServerTransport::ResponseView response{};
203 bool payload_written_direct = false;
204 };
205
// Server-wide profiling counters. Every field is an atomic 64-bit counter
// starting at zero; MaybeReportProfile drains them with Exchange when it
// prints a report.
struct ProfileCounters {
  using Counter = std::atomic<std::uint64_t>;

  // Slot scanning.
  Counter scan_rounds{0};
  Counter scanned_slots{0};
  Counter ready_slots{0};
  Counter not_ready_slots{0};
  Counter zero_seq_ready{0};
  Counter duplicate_seq_ready{0};
  Counter inflight_seq_ready{0};
  Counter empty_scan_rounds{0};
  Counter max_ready_per_round{0};

  // Requests handled, by kind, plus rejected requests.
  Counter handled_get{0};
  Counter handled_put{0};
  Counter handled_update{0};
  Counter handled_init{0};
  Counter invalid_descriptor{0};
  Counter wrong_shard{0};

  // GET timing and volume.
  Counter handle_get_ns{0};
  Counter get_batch_get_ns{0};
  Counter get_index_lookup_ns{0};
  Counter get_zero_fill_ns{0};
  Counter get_row_copy_ns{0};
  Counter get_rows{0};
  Counter get_value_bytes{0};
  Counter get_missing_rows{0};
  Counter get_direct_sg{0};
  Counter get_direct_sg_fallback{0};
  Counter get_direct_sg_ns{0};
  Counter get_direct_sg_wr{0};

  // Timing of the other request kinds and of response completion.
  Counter handle_put_ns{0};
  Counter handle_update_ns{0};
  Counter handle_init_ns{0};
  Counter complete_response_ns{0};
  Counter poll_loop_ns{0};

  // Timestamp (NowNs units) of the next report; zero until first armed.
  Counter next_report_ns{0};
};
241
// Per-polling-thread profiling counters, all atomic and starting at zero.
// MaybeReportProfile drains one of these per poller for each report.
struct PollerProfile {
  using Counter = std::atomic<std::uint64_t>;

  Counter scan_rounds{0};
  Counter scanned_slots{0};
  Counter ready_slots{0};
  Counter not_ready_slots{0};
  Counter duplicate_seq_ready{0};
  Counter inflight_seq_ready{0};
  Counter handled_get{0};
  Counter poll_loop_ns{0};
};
252
// Raises *value to `candidate` if `candidate` is larger, using a relaxed
// compare-and-swap loop; leaves *value untouched otherwise.
static void
UpdateMax(std::atomic<std::uint64_t>* value, std::uint64_t candidate) {
  std::uint64_t observed = value->load(std::memory_order_relaxed);
  for (;;) {
    if (candidate <= observed) {
      return;
    }
    // On failure `observed` is refreshed with the current value and the
    // comparison above is repeated.
    if (value->compare_exchange_weak(
            observed, candidate, std::memory_order_relaxed)) {
      return;
    }
  }
}
261
262 void MaybeReportProfile(int thread_id) {
263 if (FLAGS_rdma_rc_profile_interval_ms <= 0 || thread_id != 0) {
264 return;
265 }
266 const std::uint64_t now = NowNs();
267 const std::uint64_t interval =
268 static_cast<std::uint64_t>(FLAGS_rdma_rc_profile_interval_ms) * 1000000;
269 std::uint64_t expected =
270 profile_.next_report_ns.load(std::memory_order_relaxed);
271 if (expected == 0) {
272 profile_.next_report_ns.compare_exchange_strong(
273 expected, now + interval, std::memory_order_relaxed);
274 return;
275 }
276 if (now < expected ||
277 !profile_.next_report_ns.compare_exchange_strong(
278 expected, now + interval, std::memory_order_relaxed)) {
279 return;
280 }
281
282 const std::uint64_t scan_rounds = Exchange(&profile_.scan_rounds);
283 const std::uint64_t scanned_slots = Exchange(&profile_.scanned_slots);
284 const std::uint64_t ready_slots = Exchange(&profile_.ready_slots);
285 const std::uint64_t not_ready_slots = Exchange(&profile_.not_ready_slots);
286 const std::uint64_t zero_seq_ready = Exchange(&profile_.zero_seq_ready);
287 const std::uint64_t duplicate_seq_ready =
288 Exchange(&profile_.duplicate_seq_ready);
289 const std::uint64_t inflight_seq_ready =
290 Exchange(&profile_.inflight_seq_ready);
291 const std::uint64_t empty_scan_rounds =
292 Exchange(&profile_.empty_scan_rounds);
293 const std::uint64_t max_ready_per_round =
294 Exchange(&profile_.max_ready_per_round);
295 const std::uint64_t handled_get = Exchange(&profile_.handled_get);
296 const std::uint64_t handled_put = Exchange(&profile_.handled_put);
297 const std::uint64_t handled_update = Exchange(&profile_.handled_update);
298 const std::uint64_t handled_init = Exchange(&profile_.handled_init);
299 const std::uint64_t complete_count =
300 handled_get + handled_put + handled_update + handled_init;
301 const std::uint64_t handle_get_ns = Exchange(&profile_.handle_get_ns);
302 const std::uint64_t get_batch_get_ns = Exchange(&profile_.get_batch_get_ns);
303 const std::uint64_t get_index_lookup_ns =
304 Exchange(&profile_.get_index_lookup_ns);
305 const std::uint64_t get_zero_fill_ns = Exchange(&profile_.get_zero_fill_ns);
306 const std::uint64_t get_row_copy_ns = Exchange(&profile_.get_row_copy_ns);
307 const std::uint64_t get_rows = Exchange(&profile_.get_rows);
308 const std::uint64_t get_value_bytes = Exchange(&profile_.get_value_bytes);
309 const std::uint64_t get_missing_rows = Exchange(&profile_.get_missing_rows);
310 const std::uint64_t get_direct_sg = Exchange(&profile_.get_direct_sg);
311 const std::uint64_t get_direct_sg_ns = Exchange(&profile_.get_direct_sg_ns);
312 const std::uint64_t handle_put_ns = Exchange(&profile_.handle_put_ns);
313 const std::uint64_t handle_update_ns = Exchange(&profile_.handle_update_ns);
314 const std::uint64_t handle_init_ns = Exchange(&profile_.handle_init_ns);
315 const std::uint64_t complete_response_ns =
316 Exchange(&profile_.complete_response_ns);
317 const std::uint64_t poll_loop_ns = Exchange(&profile_.poll_loop_ns);
318 std::uint64_t poller_min_get = std::numeric_limits<std::uint64_t>::max();
319 std::uint64_t poller_max_get = 0;
320 int poller_min_get_thread = -1;
321 int poller_max_get_thread = -1;
322 std::uint64_t poller_total_get = 0;
323 std::uint64_t poller_active = 0;
324 for (std::size_t i = 0; i < poller_profiles_.size(); ++i) {
325 auto& poller = *poller_profiles_[i];
326 const std::uint64_t poller_get = Exchange(&poller.handled_get);
327 const std::uint64_t poller_scan_rounds = Exchange(&poller.scan_rounds);
328 const std::uint64_t poller_scanned_slots =
329 Exchange(&poller.scanned_slots);
330 const std::uint64_t poller_ready_slots = Exchange(&poller.ready_slots);
331 const std::uint64_t poller_not_ready_slots =
332 Exchange(&poller.not_ready_slots);
333 const std::uint64_t poller_duplicate_seq_ready =
334 Exchange(&poller.duplicate_seq_ready);
335 const std::uint64_t poller_inflight_seq_ready =
336 Exchange(&poller.inflight_seq_ready);
337 const std::uint64_t poller_poll_loop_ns = Exchange(&poller.poll_loop_ns);
338 if (poller_get > 0) {
339 ++poller_active;
340 }
341 poller_total_get += poller_get;
342 if (poller_get < poller_min_get) {
343 poller_min_get = poller_get;
344 poller_min_get_thread = static_cast<int>(i);
345 }
346 if (poller_get > poller_max_get) {
347 poller_max_get = poller_get;
348 poller_max_get_thread = static_cast<int>(i);
349 }
350 std::cout
351 << "component=rdma_rc_server_poller_profile"
352 << " shard=" << shard_id_ << " thread_id=" << i << " scan_rounds="
353 << poller_scan_rounds << " scanned_slots=" << poller_scanned_slots
354 << " ready_slots=" << poller_ready_slots << " scan_hit_pct="
355 << (poller_scanned_slots == 0
356 ? 0.0
357 : 100.0 * static_cast<double>(poller_ready_slots) /
358 static_cast<double>(poller_scanned_slots))
359 << " not_ready_slots=" << poller_not_ready_slots
360 << " duplicate_seq_ready=" << poller_duplicate_seq_ready
361 << " inflight_seq_ready=" << poller_inflight_seq_ready
362 << " handled_get=" << poller_get << " poll_loop_avg_ns="
363 << (poller_scan_rounds == 0
364 ? 0
365 : poller_poll_loop_ns / poller_scan_rounds)
366 << std::endl;
367 }
368 if (poller_min_get == std::numeric_limits<std::uint64_t>::max()) {
369 poller_min_get = 0;
370 }
371 std::cout
372 << "component=rdma_rc_server_profile"
373 << " shard=" << shard_id_ << " threads=" << thread_count_
374 << " scan_rounds=" << scan_rounds << " scanned_slots=" << scanned_slots
375 << " ready_slots=" << ready_slots << " not_ready_slots="
376 << not_ready_slots << " zero_seq_ready=" << zero_seq_ready
377 << " duplicate_seq_ready=" << duplicate_seq_ready
378 << " inflight_seq_ready=" << inflight_seq_ready
379 << " empty_scan_rounds=" << empty_scan_rounds << " scan_hit_pct="
380 << (scanned_slots == 0 ? 0.0
381 : 100.0 * static_cast<double>(ready_slots) /
382 static_cast<double>(scanned_slots))
383 << " ready_round_pct="
384 << (scan_rounds == 0
385 ? 0.0
386 : 100.0 * static_cast<double>(scan_rounds - empty_scan_rounds) /
387 static_cast<double>(scan_rounds))
388 << " avg_ready_per_round="
389 << (scan_rounds == 0 ? 0.0
390 : static_cast<double>(ready_slots) /
391 static_cast<double>(scan_rounds))
392 << " max_ready_per_round=" << max_ready_per_round
393 << " handled_get=" << handled_get << " handled_put=" << handled_put
394 << " handled_update=" << handled_update
395 << " handled_init=" << handled_init
396 << " invalid_descriptor=" << Exchange(&profile_.invalid_descriptor)
397 << " wrong_shard=" << Exchange(&profile_.wrong_shard)
398 << " handle_get_avg_ns="
399 << (handled_get == 0 ? 0 : handle_get_ns / handled_get)
400 << " get_batch_get_avg_ns="
401 << (handled_get == 0 ? 0 : get_batch_get_ns / handled_get)
402 << " get_index_lookup_avg_ns="
403 << (handled_get == 0 ? 0 : get_index_lookup_ns / handled_get)
404 << " get_zero_fill_avg_ns="
405 << (handled_get == 0 ? 0 : get_zero_fill_ns / handled_get)
406 << " get_row_copy_avg_ns="
407 << (handled_get == 0 ? 0 : get_row_copy_ns / handled_get)
408 << " get_rows=" << get_rows << " get_value_bytes=" << get_value_bytes
409 << " get_missing_rows=" << get_missing_rows
410 << " get_direct_sg=" << get_direct_sg << " get_direct_sg_fallback="
411 << Exchange(&profile_.get_direct_sg_fallback)
412 << " get_direct_sg_avg_ns="
413 << (get_direct_sg == 0 ? 0 : get_direct_sg_ns / get_direct_sg)
414 << " get_direct_sg_wr=" << Exchange(&profile_.get_direct_sg_wr)
415 << " handle_put_avg_ns="
416 << (handled_put == 0 ? 0 : handle_put_ns / handled_put)
417 << " handle_update_avg_ns="
418 << (handled_update == 0 ? 0 : handle_update_ns / handled_update)
419 << " handle_init_avg_ns="
420 << (handled_init == 0 ? 0 : handle_init_ns / handled_init)
421 << " complete_response_avg_ns="
422 << (complete_count == 0 ? 0 : complete_response_ns / complete_count)
423 << " poll_loop_avg_ns="
424 << (scan_rounds == 0 ? 0 : poll_loop_ns / scan_rounds)
425 << " poller_active=" << poller_active << " poller_total_get="
426 << poller_total_get << " poller_min_get=" << poller_min_get
427 << " poller_min_get_thread=" << poller_min_get_thread
428 << " poller_max_get=" << poller_max_get
429 << " poller_max_get_thread=" << poller_max_get_thread << std::endl;
430 }
431
432 bool GetPayloadOffloadEnabled() const {
433 return get_payload_worker_count_ > 0;
434 }
435
436 std::size_t MaxGetPayloadQueueDepth() const {
437 return static_cast<std::size_t>(std::max(1, transport_->TotalSlots()));
438 }
439
440 void StartGetPayloadWorkers() {
441 if (!GetPayloadOffloadEnabled()) {
442 return;
443 }
444 for (int worker_id = 0; worker_id < get_payload_worker_count_;
445 ++worker_id) {
446 get_payload_workers_.emplace_back(
447 &PetPSServer::GetPayloadWorkerLoop, this, worker_id);
448 }
449 LOG(INFO) << "component=rdma_rc_server event=get_payload_workers_started"
450 << " count=" << get_payload_worker_count_;
451 }
452
453 void BindServerCore(int core_index) {
454 base::bind_core_with_env_offset(core_index);
455 }
456
457 bool EnqueueGetPayloadTask(const GetPayloadTask& task) {
458 std::lock_guard<std::mutex> guard(get_payload_mu_);
459 if (get_payload_tasks_.size() >= MaxGetPayloadQueueDepth()) {
460 return false;
461 }
462 get_payload_tasks_.push_back(task);
463 get_payload_cv_.notify_one();
464 return true;
465 }
466
467 std::size_t PollThreadIndex(int poll_thread_id) const {
468 return static_cast<std::size_t>(poll_thread_id);
469 }
470
471 bool TryPopGetPayloadCompletion(int poll_thread_id,
472 GetPayloadCompletion* completion) {
473 std::lock_guard<std::mutex> guard(get_payload_mu_);
474 auto& completions =
475 get_payload_completions_.at(PollThreadIndex(poll_thread_id));
476 if (completions.empty()) {
477 return false;
478 }
479 *completion = completions.front();
480 completions.pop_front();
481 return true;
482 }
483
484 void PushGetPayloadCompletion(const GetPayloadCompletion& completion) {
485 std::lock_guard<std::mutex> guard(get_payload_mu_);
486 get_payload_completions_.at(PollThreadIndex(completion.poll_thread_id))
487 .push_back(completion);
488 }
489
490 void AccumulateFlatGetProfile(const CachePS::FlatGetProfile& get_profile) {
491 profile_.get_batch_get_ns.fetch_add(
492 get_profile.batch_get_ns, std::memory_order_relaxed);
493 profile_.get_index_lookup_ns.fetch_add(
494 get_profile.index_lookup_ns, std::memory_order_relaxed);
495 profile_.get_zero_fill_ns.fetch_add(
496 get_profile.zero_fill_ns, std::memory_order_relaxed);
497 profile_.get_row_copy_ns.fetch_add(
498 get_profile.row_copy_ns, std::memory_order_relaxed);
499 profile_.get_rows.fetch_add(get_profile.rows, std::memory_order_relaxed);
500 profile_.get_value_bytes.fetch_add(
501 get_profile.value_bytes, std::memory_order_relaxed);
502 profile_.get_missing_rows.fetch_add(
503 get_profile.missing_rows, std::memory_order_relaxed);
504 }
505
506 void GetPayloadWorkerLoop(int worker_id) {
507 BindServerCore(thread_count_ + worker_id);
508 LOG(INFO) << "component=rdma_rc_server event=get_payload_worker_ready"
509 << " worker_id=" << worker_id;
510 while (true) {
511 GetPayloadTask task;
512 {
513 std::unique_lock<std::mutex> lock(get_payload_mu_);
514 get_payload_cv_.wait(lock, [this] {
515 return !get_payload_tasks_.empty();
516 });
517 task = get_payload_tasks_.front();
518 get_payload_tasks_.pop_front();
519 }
520
521 const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0;
522 const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0;
523 const bool payload_written_direct = HandleGet(
524 task.descriptor,
525 task.payload,
526 &task.response,
527 worker_id,
528 task.slot_in_qp);
529 if (profile_enabled) {
530 profile_.handled_get.fetch_add(1, std::memory_order_relaxed);
531 profile_.handle_get_ns.fetch_add(
532 NowNs() - handle_start_ns, std::memory_order_relaxed);
533 }
534 const GetPayloadCompletion completion{
535 task.slot,
536 task.client_id,
537 task.qp_index,
538 task.slot_in_qp,
539 task.poll_thread_id,
540 task.seq,
541 task.response,
542 payload_written_direct,
543 };
544 PushGetPayloadCompletion(completion);
545 }
546 }
547
548 void CompleteResponseForSlot(
549 int slot,
550 int client_id,
551 int qp_index,
552 int slot_in_qp,
553 const petps::RcShardServerTransport::ResponseView& response,
554 std::uint64_t seq,
555 bool profile_enabled) {
556 std::atomic_thread_fence(std::memory_order_release);
557 const std::uint64_t complete_start_ns = profile_enabled ? NowNs() : 0;
558 transport_->CompleteResponse(
559 client_id, qp_index, slot_in_qp, response, seq);
560 if (profile_enabled) {
561 profile_.complete_response_ns.fetch_add(
562 NowNs() - complete_start_ns, std::memory_order_relaxed);
563 }
564 VLOG(1) << "component=rdma_rc_server event=complete shard=" << shard_id_
565 << " slot=" << slot << " client_id=" << client_id
566 << " qp=" << qp_index << " seq=" << seq
567 << " status=" << response.status->status
568 << " response_bytes=" << response.status->response_bytes;
569 last_seq_[static_cast<std::size_t>(slot)] = seq;
570 if (GetPayloadOffloadEnabled()) {
571 inflight_seq_[static_cast<std::size_t>(slot)] = 0;
572 }
573 }
574
575 void CompleteResponseStatusOnlyForSlot(
576 int slot,
577 int client_id,
578 int qp_index,
579 int slot_in_qp,
580 const petps::RcShardServerTransport::ResponseView& response,
581 std::uint64_t seq,
582 bool profile_enabled) {
583 std::atomic_thread_fence(std::memory_order_release);
584 const std::uint64_t complete_start_ns = profile_enabled ? NowNs() : 0;
585 transport_->CompleteResponseStatusOnly(
586 client_id, qp_index, slot_in_qp, response, seq);
587 if (profile_enabled) {
588 profile_.complete_response_ns.fetch_add(
589 NowNs() - complete_start_ns, std::memory_order_relaxed);
590 }
591 VLOG(1) << "component=rdma_rc_server event=complete_direct shard="
592 << shard_id_ << " slot=" << slot << " client_id=" << client_id
593 << " qp=" << qp_index << " seq=" << seq
594 << " status=" << response.status->status
595 << " response_bytes=" << response.status->response_bytes;
596 last_seq_[static_cast<std::size_t>(slot)] = seq;
597 if (GetPayloadOffloadEnabled()) {
598 inflight_seq_[static_cast<std::size_t>(slot)] = 0;
599 }
600 }
601
602 void DrainGetPayloadCompletions(int poll_thread_id, bool profile_enabled) {
603 GetPayloadCompletion completion;
604 while (TryPopGetPayloadCompletion(poll_thread_id, &completion)) {
605 if (completion.payload_written_direct) {
606 CompleteResponseStatusOnlyForSlot(
607 completion.slot,
608 completion.client_id,
609 completion.qp_index,
610 completion.slot_in_qp,
611 completion.response,
612 completion.seq,
613 profile_enabled);
614 } else {
615 CompleteResponseForSlot(
616 completion.slot,
617 completion.client_id,
618 completion.qp_index,
619 completion.slot_in_qp,
620 completion.response,
621 completion.seq,
622 profile_enabled);
623 }
624 }
625 }
626
627 bool HandleGetDirectSg(
628 const petps::RequestDescriptor& descriptor,
629 base::ConstArray<std::uint64_t> keys,
630 petps::RcShardServerTransport::ResponseView* response,
631 int thread_id,
632 int slot_in_qp,
633 CachePS::FlatGetProfile* get_profile) {
634 if (descriptor.response_bytes == 0 || descriptor.embedding_dim == 0) {
635 return false;
636 }
637 const std::size_t row_bytes =
638 static_cast<std::size_t>(descriptor.embedding_dim) * sizeof(float);
639 if (row_bytes == 0 ||
640 descriptor.response_bytes !=
641 descriptor.key_count * static_cast<std::uint32_t>(row_bytes)) {
642 return false;
643 }
644
645 thread_local std::vector<CachePS::DirectFixedRow> rows;
646 rows.clear();
647 const std::uint64_t direct_start_ns =
648 FLAGS_rdma_rc_profile_interval_ms > 0 ? NowNs() : 0;
649 const bool ok = cache_ps_->GetParameterDirectFixedRows(
650 keys,
651 descriptor.key_count,
652 descriptor.embedding_dim,
653 thread_id,
654 &rows,
655 get_profile);
656 if (!ok || rows.size() != descriptor.key_count) {
657 return false;
658 }
659 std::uint64_t response_offset = 0;
660 std::uint64_t wr_count = 0;
661 for (std::size_t row = 0; row < rows.size();) {
662 std::array<petps::RawVerbsSge, kMaxDirectSgesPerWr> sges{};
663 std::size_t sge_count = 0;
664 std::size_t row_count = 0;
665 for (; row < rows.size(); ++row) {
666 const auto& ref = rows[row];
667 if (ref.missing || ref.data == nullptr || ref.size != row_bytes) {
668 return false;
669 }
670 if (sge_count > 0) {
671 auto& last = sges[sge_count - 1];
672 const char* last_end =
673 static_cast<const char*>(last.data) + last.bytes;
674 if (last_end == ref.data) {
675 last.bytes += row_bytes;
676 ++row_count;
677 continue;
678 }
679 }
680 if (sge_count == kMaxDirectSgesPerWr) {
681 break;
682 }
683 sges[sge_count++] = petps::RawVerbsSge{ref.data, row_bytes};
684 ++row_count;
685 }
686 const std::uint64_t bytes =
687 static_cast<std::uint64_t>(row_count * row_bytes);
688 transport_->WriteResponsePayloadSg(
689 descriptor.client_id,
690 descriptor.qp_index,
691 slot_in_qp,
692 base::ConstArray<petps::RawVerbsSge>(
693 sges.data(), static_cast<int>(sge_count)),
694 response_offset,
695 bytes);
696 response_offset += bytes;
697 ++wr_count;
698 }
699 response->status->status = static_cast<std::int32_t>(petps::RpcStatus::kOk);
700 response->status->response_bytes =
701 static_cast<std::uint32_t>(descriptor.response_bytes);
702 if (FLAGS_rdma_rc_profile_interval_ms > 0) {
703 profile_.get_direct_sg.fetch_add(1, std::memory_order_relaxed);
704 profile_.get_direct_sg_ns.fetch_add(
705 NowNs() - direct_start_ns, std::memory_order_relaxed);
706 profile_.get_direct_sg_wr.fetch_add(wr_count, std::memory_order_relaxed);
707 if (get_profile != nullptr) {
708 AccumulateFlatGetProfile(*get_profile);
709 }
710 }
711 return true;
712 }
713
714 bool HandleGet(const petps::RequestDescriptor& descriptor,
715 const char* payload,
716 petps::RcShardServerTransport::ResponseView* response,
717 int thread_id,
718 int slot_in_qp) {
719 if (FLAGS_rdma_rc_fake_get_mode == "status_only") {
720 response->status->status =
721 static_cast<std::int32_t>(petps::RpcStatus::kOk);
722 response->status->response_bytes = 0;
723 return false;
724 }
725 if (FLAGS_rdma_rc_fake_get_mode == "payload_memset") {
726 std::memset(response->payload, 0, descriptor.response_bytes);
727 response->status->status =
728 static_cast<std::int32_t>(petps::RpcStatus::kOk);
729 response->status->response_bytes =
730 static_cast<std::uint32_t>(descriptor.response_bytes);
731 return false;
732 }
733 if (FLAGS_rdma_rc_fake_get_mode == "index_only") {
734 base::ConstArray<std::uint64_t> keys(
735 reinterpret_cast<const std::uint64_t*>(payload),
736 descriptor.key_count);
737 CachePS::FlatGetProfile get_profile;
738 CachePS::FlatGetProfile* get_profile_ptr =
739 FLAGS_rdma_rc_profile_interval_ms > 0 ? &get_profile : nullptr;
740 const bool ok =
741 cache_ps_->ProbeParameterIndex(keys, thread_id, get_profile_ptr);
742 if (get_profile_ptr != nullptr) {
743 AccumulateFlatGetProfile(get_profile);
744 }
745 response->status->status = static_cast<std::int32_t>(
746 ok ? petps::RpcStatus::kOk : petps::RpcStatus::kValueSizeMismatch);
747 response->status->response_bytes = 0;
748 return false;
749 }
750 if (FLAGS_rdma_rc_fake_get_mode != "none" &&
751 !FLAGS_rdma_rc_fake_get_mode.empty()) {
752 response->status->status =
753 static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload);
754 response->status->response_bytes = 0;
755 return false;
756 }
757
758 base::ConstArray<std::uint64_t> keys(
759 reinterpret_cast<const std::uint64_t*>(payload), descriptor.key_count);
760 CachePS::FlatGetProfile get_profile;
761 CachePS::FlatGetProfile* get_profile_ptr =
762 FLAGS_rdma_rc_profile_interval_ms > 0 ? &get_profile : nullptr;
763 if ((descriptor.flags & petps::kRcFlagGetDirectSg) != 0) {
764 const bool direct_ok = HandleGetDirectSg(
765 descriptor, keys, response, thread_id, slot_in_qp, get_profile_ptr);
766 if (direct_ok) {
767 return true;
768 }
769 if (FLAGS_rdma_rc_profile_interval_ms > 0) {
770 profile_.get_direct_sg_fallback.fetch_add(1, std::memory_order_relaxed);
771 }
772 if ((descriptor.flags & petps::kRcFlagGetAllowFallbackCopy) == 0) {
773 response->status->status =
774 static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload);
775 response->status->response_bytes = 0;
776 return false;
777 }
778 }
779 const bool ok = cache_ps_->GetParameterFlat(
780 keys,
781 reinterpret_cast<float*>(response->payload),
782 descriptor.key_count,
783 descriptor.embedding_dim,
784 thread_id,
785 get_profile_ptr);
786 if (get_profile_ptr != nullptr) {
787 AccumulateFlatGetProfile(get_profile);
788 }
789 response->status->status = static_cast<std::int32_t>(
790 ok ? petps::RpcStatus::kOk : petps::RpcStatus::kValueSizeMismatch);
791 response->status->response_bytes =
792 static_cast<std::uint32_t>(descriptor.response_bytes);
793 return false;
794 }
795
796 void HandlePut(const petps::RequestDescriptor& descriptor,
797 const char* payload,
798 petps::RcShardServerTransport::ResponseView* response,
799 int thread_id) {
800 const auto* reader =
801 reinterpret_cast<const ParameterCompressReader*>(payload);
802 if (!reader->Valid(static_cast<int>(descriptor.payload_bytes))) {
803 response->status->status =
804 static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload);
805 response->status->response_bytes = 0;
806 return;
807 }
808 for (int i = 0; i < reader->item_size(); ++i) {
809 cache_ps_->PutSingleParameter(reader->item(i), thread_id);
810 }
811 response->status->status = static_cast<std::int32_t>(petps::RpcStatus::kOk);
812 response->status->response_bytes = 0;
813 }
814
815 void HandleUpdate(const petps::RequestDescriptor& descriptor,
816 const char* payload,
817 petps::RcShardServerTransport::ResponseView* response,
818 int thread_id) {
819 const std::string_view table_name = petps::DescriptorTableName(descriptor);
820 if (table_name.empty()) {
821 response->status->status =
822 static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload);
823 response->status->response_bytes = 0;
824 return;
825 }
826
827 const auto* reader =
828 reinterpret_cast<const ParameterCompressReader*>(payload);
829 if (!reader->Valid(static_cast<int>(descriptor.payload_bytes))) {
830 response->status->status =
831 static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload);
832 response->status->response_bytes = 0;
833 return;
834 }
835
836 const bool ok = cache_ps_->UpdateParameter(
837 std::string(table_name), reader, static_cast<unsigned>(thread_id));
838 response->status->status = static_cast<std::int32_t>(
839 ok ? petps::RpcStatus::kOk : petps::RpcStatus::kInvalidPayload);
840 response->status->response_bytes = 0;
841 }
842
843 void HandleInitTable(const petps::RequestDescriptor& descriptor,
844 const char* payload,
845 petps::RcShardServerTransport::ResponseView* response) {
846 const std::string_view table_name = petps::DescriptorTableName(descriptor);
847 if (table_name.empty() ||
848 descriptor.payload_bytes != petps::InitTablePayloadBytes()) {
849 response->status->status =
850 static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload);
851 response->status->response_bytes = 0;
852 return;
853 }
854
855 std::uint64_t num_embeddings = 0;
856 std::uint64_t embedding_dim = 0;
857 std::memcpy(&num_embeddings, payload, sizeof(num_embeddings));
858 std::memcpy(&embedding_dim,
859 payload + sizeof(num_embeddings),
860 sizeof(embedding_dim));
861 const bool ok = cache_ps_->InitTable(
862 std::string(table_name), num_embeddings, embedding_dim);
863 response->status->status = static_cast<std::int32_t>(
864 ok ? petps::RpcStatus::kOk : petps::RpcStatus::kInvalidPayload);
865 response->status->response_bytes = 0;
866 }
867
868 void MaybePublishServerReady() {
869 const int started =
870 started_threads_.fetch_add(1, std::memory_order_relaxed) + 1;
871 if (started != thread_count_ ||
872 ready_published_.exchange(true, std::memory_order_acq_rel)) {
873 return;
874 }
875 control_plane_client_.PublishServerReady(FLAGS_global_id);
876 LOG(INFO) << "component=rdma_control_plane event=server_ready_published"
877 << " server_id=" << FLAGS_global_id
878 << " host=" << FLAGS_rdma_control_plane_host
879 << " port=" << FLAGS_rdma_control_plane_port;
880 }
881
882 void PollingThread(int thread_id) {
883 BindServerCore(thread_id);
884 LOG(INFO) << "component=rdma_server event=polling_thread_ready thread_id="
885 << thread_id;
886 MaybePublishServerReady();
887 const int coroutines_per_thread =
888 std::max(1, FLAGS_rdma_rc_server_coroutines_per_thread);
889 LOG(INFO) << "component=rdma_rc_server event=polling_thread_mode"
890 << " thread_id=" << thread_id
891 << " coroutines_per_thread=" << coroutines_per_thread;
892 if (coroutines_per_thread > 1) {
893 RunCoroutinePollingThread(thread_id, coroutines_per_thread);
894 return;
895 }
896 while (true) {
897 const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0;
898 const std::uint64_t poll_start_ns = profile_enabled ? NowNs() : 0;
899 std::uint64_t scanned_slots = 0;
900 std::uint64_t ready_slots = 0;
901 DrainGetPayloadCompletions(thread_id, profile_enabled);
902 ScanAssignedSlots(
903 thread_id,
904 /*worker_id=*/0,
905 /*worker_count=*/1,
906 profile_enabled,
907 &scanned_slots,
908 &ready_slots);
909 DrainGetPayloadCompletions(thread_id, profile_enabled);
910 if (profile_enabled) {
911 profile_.scan_rounds.fetch_add(1, std::memory_order_relaxed);
912 profile_.scanned_slots.fetch_add(
913 scanned_slots, std::memory_order_relaxed);
914 if (ready_slots == 0) {
915 profile_.empty_scan_rounds.fetch_add(1, std::memory_order_relaxed);
916 }
917 UpdateMax(&profile_.max_ready_per_round, ready_slots);
918 const std::uint64_t poll_loop_ns = NowNs() - poll_start_ns;
919 profile_.poll_loop_ns.fetch_add(
920 poll_loop_ns, std::memory_order_relaxed);
921 auto& poller =
922 *poller_profiles_.at(static_cast<std::size_t>(thread_id));
923 poller.scan_rounds.fetch_add(1, std::memory_order_relaxed);
924 poller.scanned_slots.fetch_add(
925 scanned_slots, std::memory_order_relaxed);
926 poller.ready_slots.fetch_add(ready_slots, std::memory_order_relaxed);
927 poller.poll_loop_ns.fetch_add(poll_loop_ns, std::memory_order_relaxed);
928 MaybeReportProfile(thread_id);
929 }
930 std::this_thread::yield();
931 }
932 }
933
934 bool ProcessSlot(int slot, int thread_id, bool profile_enabled) {
935 int client_id = -1;
936 int qp_index = -1;
937 int slot_in_qp = -1;
938 transport_->DecodeSlotIndex(slot, &client_id, &qp_index, &slot_in_qp);
939 auto* commit = transport_->RequestCommitAt(slot);
940 if (commit->state.load(std::memory_order_acquire) != petps::kRcSlotReady) {
941 if (profile_enabled) {
942 profile_.not_ready_slots.fetch_add(1, std::memory_order_relaxed);
943 poller_profiles_.at(static_cast<std::size_t>(thread_id))
944 ->not_ready_slots.fetch_add(1, std::memory_order_relaxed);
945 }
946 return false;
947 }
948 const std::uint64_t seq = commit->seq.load(std::memory_order_acquire);
949 if (seq == 0) {
950 if (profile_enabled) {
951 profile_.zero_seq_ready.fetch_add(1, std::memory_order_relaxed);
952 }
953 return false;
954 }
955 if (seq == last_seq_[static_cast<std::size_t>(slot)]) {
956 if (profile_enabled) {
957 profile_.duplicate_seq_ready.fetch_add(1, std::memory_order_relaxed);
958 poller_profiles_.at(static_cast<std::size_t>(thread_id))
959 ->duplicate_seq_ready.fetch_add(1, std::memory_order_relaxed);
960 }
961 return false;
962 }
963 if (GetPayloadOffloadEnabled() &&
964 seq == inflight_seq_[static_cast<std::size_t>(slot)]) {
965 if (profile_enabled) {
966 profile_.inflight_seq_ready.fetch_add(1, std::memory_order_relaxed);
967 poller_profiles_.at(static_cast<std::size_t>(thread_id))
968 ->inflight_seq_ready.fetch_add(1, std::memory_order_relaxed);
969 }
970 return false;
971 }
972 if (profile_enabled) {
973 profile_.ready_slots.fetch_add(1, std::memory_order_relaxed);
974 }
975
976 auto* descriptor = transport_->RequestDescriptorAt(slot);
977 std::string error;
978 if (!petps::ValidateRequestDescriptor(
979 *descriptor,
980 transport_->config().request_slot_bytes,
981 transport_->config().response_slot_bytes,
982 &error)) {
983 LOG(ERROR) << "component=rdma_rc_server event=invalid_descriptor"
984 << " shard=" << shard_id_ << " slot=" << slot
985 << " thread_id=" << thread_id << " seq=" << seq
986 << " descriptor_seq=" << descriptor->seq
987 << " client_id=" << descriptor->client_id
988 << " qp=" << descriptor->qp_index << " op=" << descriptor->op
989 << " key_count=" << descriptor->key_count
990 << " payload_bytes=" << descriptor->payload_bytes
991 << " response_bytes=" << descriptor->response_bytes
992 << " error=\"" << error << "\"";
993 if (profile_enabled) {
994 profile_.invalid_descriptor.fetch_add(1, std::memory_order_relaxed);
995 }
996 last_seq_[static_cast<std::size_t>(slot)] = seq;
997 commit->state.store(0, std::memory_order_release);
998 return true;
999 }
1000 if (descriptor->client_id != static_cast<std::uint32_t>(client_id) ||
1001 descriptor->qp_index != static_cast<std::uint32_t>(qp_index)) {
1002 LOG(ERROR) << "component=rdma_rc_server event=slot_descriptor_mismatch"
1003 << " shard=" << shard_id_ << " slot=" << slot
1004 << " thread_id=" << thread_id
1005 << " slot_client_id=" << client_id << " slot_qp=" << qp_index
1006 << " descriptor_client_id=" << descriptor->client_id
1007 << " descriptor_qp=" << descriptor->qp_index << " seq=" << seq;
1008 if (profile_enabled) {
1009 profile_.invalid_descriptor.fetch_add(1, std::memory_order_relaxed);
1010 }
1011 last_seq_[static_cast<std::size_t>(slot)] = seq;
1012 commit->state.store(0, std::memory_order_release);
1013 return true;
1014 }
1015
1016 auto response =
1017 transport_->OpenClientResponse(client_id, qp_index, slot_in_qp);
1018 const char* payload = transport_->RequestPayloadAt(slot);
1019 VLOG(1) << "component=rdma_rc_server event=consume shard=" << shard_id_
1020 << " slot=" << slot << " client_id=" << descriptor->client_id
1021 << " qp=" << descriptor->qp_index << " seq=" << seq << " op="
1022 << descriptor->op << " key_count=" << descriptor->key_count
1023 << " payload_bytes=" << descriptor->payload_bytes
1024 << " response_bytes=" << descriptor->response_bytes;
1025 response.status->status =
1026 static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload);
1027 response.status->response_bytes = 0;
1028
1029 if (descriptor->shard_id != static_cast<std::uint32_t>(shard_id_)) {
1030 LOG(ERROR) << "component=rdma_rc_server event=wrong_shard"
1031 << " expected_shard=" << shard_id_
1032 << " actual_shard=" << descriptor->shard_id << " slot=" << slot
1033 << " client_id=" << descriptor->client_id
1034 << " qp=" << descriptor->qp_index << " seq=" << seq << " op="
1035 << descriptor->op << " key_count=" << descriptor->key_count;
1036 if (profile_enabled) {
1037 profile_.wrong_shard.fetch_add(1, std::memory_order_relaxed);
1038 }
1039 response.status->status =
1040 static_cast<std::int32_t>(petps::RpcStatus::kWrongShard);
1041 } else if (descriptor->op ==
1042 static_cast<std::uint16_t>(petps::RcOp::kGet)) {
1043 if (GetPayloadOffloadEnabled()) {
1044 const GetPayloadTask task{
1045 slot,
1046 client_id,
1047 qp_index,
1048 slot_in_qp,
1049 thread_id,
1050 seq,
1051 *descriptor,
1052 payload,
1053 response,
1054 };
1055 if (!EnqueueGetPayloadTask(task)) {
1056 return false;
1057 }
1058 inflight_seq_[static_cast<std::size_t>(slot)] = seq;
1059 return true;
1060 } else {
1061 const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0;
1062 const bool payload_written_direct =
1063 HandleGet(*descriptor, payload, &response, thread_id, slot_in_qp);
1064 if (profile_enabled) {
1065 profile_.handled_get.fetch_add(1, std::memory_order_relaxed);
1066 profile_.handle_get_ns.fetch_add(
1067 NowNs() - handle_start_ns, std::memory_order_relaxed);
1068 poller_profiles_.at(static_cast<std::size_t>(thread_id))
1069 ->handled_get.fetch_add(1, std::memory_order_relaxed);
1070 }
1071 if (payload_written_direct) {
1072 CompleteResponseStatusOnlyForSlot(
1073 slot,
1074 client_id,
1075 qp_index,
1076 slot_in_qp,
1077 response,
1078 seq,
1079 profile_enabled);
1080 return true;
1081 }
1082 }
1083 } else if (descriptor->op ==
1084 static_cast<std::uint16_t>(petps::RcOp::kPut)) {
1085 const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0;
1086 HandlePut(*descriptor, payload, &response, thread_id);
1087 if (profile_enabled) {
1088 profile_.handled_put.fetch_add(1, std::memory_order_relaxed);
1089 profile_.handle_put_ns.fetch_add(
1090 NowNs() - handle_start_ns, std::memory_order_relaxed);
1091 }
1092 } else if (descriptor->op ==
1093 static_cast<std::uint16_t>(petps::RcOp::kUpdate)) {
1094 const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0;
1095 HandleUpdate(*descriptor, payload, &response, thread_id);
1096 if (profile_enabled) {
1097 profile_.handled_update.fetch_add(1, std::memory_order_relaxed);
1098 profile_.handle_update_ns.fetch_add(
1099 NowNs() - handle_start_ns, std::memory_order_relaxed);
1100 }
1101 } else if (descriptor->op ==
1102 static_cast<std::uint16_t>(petps::RcOp::kInitTable)) {
1103 const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0;
1104 HandleInitTable(*descriptor, payload, &response);
1105 if (profile_enabled) {
1106 profile_.handled_init.fetch_add(1, std::memory_order_relaxed);
1107 profile_.handle_init_ns.fetch_add(
1108 NowNs() - handle_start_ns, std::memory_order_relaxed);
1109 }
1110 }
1111
1112 CompleteResponseForSlot(
1113 slot, client_id, qp_index, slot_in_qp, response, seq, profile_enabled);
1114 return true;
1115 }
1116
1117 void ScanAssignedSlots(
1118 int thread_id,
1119 int worker_id,
1120 int worker_count,
1121 bool profile_enabled,
1122 std::uint64_t* scanned_slots,
1123 std::uint64_t* ready_slots) {
1124 const int qp_count = transport_->config().qps_per_client_per_shard;
1125 const int slots_per_qp = transport_->config().slots_per_qp;
1126 const int num_clients = transport_->config().num_clients;
1127 const int lane_slots = num_clients * slots_per_qp;
1128 for (int qp_index = thread_id; qp_index < qp_count;
1129 qp_index += thread_count_) {
1130 for (int lane_slot = worker_id; lane_slot < lane_slots;
1131 lane_slot += worker_count) {
1132 const int client_id = lane_slot / slots_per_qp;
1133 const int slot_in_qp = lane_slot % slots_per_qp;
1134 const int slot_index =
1135 transport_->SlotIndex(client_id, qp_index, slot_in_qp);
1136 ++(*scanned_slots);
1137 if (ProcessSlot(slot_index, thread_id, profile_enabled)) {
1138 ++(*ready_slots);
1139 }
1140 }
1141 }
1142 }
1143
1144 void CoroutineSlotScanner(
1145 boost::coroutines2::coroutine<void>::push_type& sink,
1146 int thread_id,
1147 int worker_id,
1148 int worker_count) {
1149 while (true) {
1150 const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0;
1151 const std::uint64_t poll_start_ns = profile_enabled ? NowNs() : 0;
1152 std::uint64_t scanned_slots = 0;
1153 std::uint64_t ready_slots = 0;
1154 DrainGetPayloadCompletions(thread_id, profile_enabled);
1155 ScanAssignedSlots(
1156 thread_id,
1157 worker_id,
1158 worker_count,
1159 profile_enabled,
1160 &scanned_slots,
1161 &ready_slots);
1162 DrainGetPayloadCompletions(thread_id, profile_enabled);
1163 if (profile_enabled) {
1164 profile_.scan_rounds.fetch_add(1, std::memory_order_relaxed);
1165 profile_.scanned_slots.fetch_add(
1166 scanned_slots, std::memory_order_relaxed);
1167 if (ready_slots == 0) {
1168 profile_.empty_scan_rounds.fetch_add(1, std::memory_order_relaxed);
1169 }
1170 UpdateMax(&profile_.max_ready_per_round, ready_slots);
1171 const std::uint64_t poll_loop_ns = NowNs() - poll_start_ns;
1172 profile_.poll_loop_ns.fetch_add(
1173 poll_loop_ns, std::memory_order_relaxed);
1174 auto& poller =
1175 *poller_profiles_.at(static_cast<std::size_t>(thread_id));
1176 poller.scan_rounds.fetch_add(1, std::memory_order_relaxed);
1177 poller.scanned_slots.fetch_add(
1178 scanned_slots, std::memory_order_relaxed);
1179 poller.ready_slots.fetch_add(ready_slots, std::memory_order_relaxed);
1180 poller.poll_loop_ns.fetch_add(poll_loop_ns, std::memory_order_relaxed);
1181 }
1182 sink();
1183 }
1184 }
1185
1186 void RunCoroutinePollingThread(int thread_id, int coroutines_per_thread) {
1187 using Coroutine = boost::coroutines2::coroutine<void>;
1188 std::vector<std::unique_ptr<Coroutine::pull_type>> coroutines;
1189 coroutines.reserve(static_cast<std::size_t>(coroutines_per_thread));
1190 for (int coroutine_id = 0; coroutine_id < coroutines_per_thread;
1191 ++coroutine_id) {
1192 coroutines.emplace_back(std::make_unique<Coroutine::pull_type>(
1193 [this, thread_id, coroutine_id, coroutines_per_thread](
1194 Coroutine::push_type& sink) {
1195 CoroutineSlotScanner(
1196 sink, thread_id, coroutine_id, coroutines_per_thread);
1197 }));
1198 }
1199 while (true) {
1200 for (auto& coroutine : coroutines) {
1201 (*coroutine)();
1202 }
1203 MaybeReportProfile(thread_id);
1204 std::this_thread::yield();
1205 }
1206 }
1207
// Parameter store the request handlers call into (PutSingleParameter,
// UpdateParameter, InitTable). Held as a raw pointer.
1208 CachePS* cache_ps_ = nullptr;
// Number of polling threads: the QP stride in ScanAssignedSlots and the
// start count MaybePublishServerReady waits for.
1209 int thread_count_ = 1;
// Shard served by this instance; ProcessSlot answers requests addressed to a
// different shard with kWrongShard.
1210 int shard_id_ = 0;
// RC transport: slot index mapping, request commit/descriptor/payload access,
// client response views and configuration.
1211 std::unique_ptr<petps::RcShardServerTransport> transport_;
// Used by MaybePublishServerReady to publish readiness.
1212 petps::RdmaControlPlaneClient control_plane_client_;
// NOTE(review): presumably the polling threads; their creation is not shown
// in this section.
1213 std::vector<std::thread> threads_;
// Indexed by slot. ProcessSlot skips a slot whose commit seq equals this
// value, and stores the seq here when it drops an invalid request.
1214 std::vector<std::uint64_t> last_seq_;
// Indexed by slot. Seq of a GET handed to the offload queue; ProcessSlot
// skips the slot while the commit seq equals it and offload is enabled.
1215 std::vector<std::uint64_t> inflight_seq_;
// Per-polling-thread profile counters, indexed by thread_id.
1216 std::vector<std::unique_ptr<PollerProfile>> poller_profiles_;
// State of the GET payload offload path (see EnqueueGetPayloadTask and
// DrainGetPayloadCompletions). NOTE(review): worker threads, task queue with
// its mutex/condition variable, and completion queues, judging by the names;
// the code using them is outside this section.
1217 int get_payload_worker_count_ = 0;
1218 std::vector<std::thread> get_payload_workers_;
1219 std::mutex get_payload_mu_;
1220 std::condition_variable get_payload_cv_;
1221 std::deque<GetPayloadTask> get_payload_tasks_;
1222 std::vector<std::deque<GetPayloadCompletion>> get_payload_completions_;
// Number of threads that have called MaybePublishServerReady.
1223 std::atomic<int> started_threads_{0};
// Set by MaybePublishServerReady so readiness is published at most once.
1224 std::atomic<bool> ready_published_{false};
// Aggregate profile counters, updated only when profiling is enabled.
1225 ProfileCounters profile_;
1226 };
1227
1228 } // namespace
1229
1230 int main(int argc, char* argv[]) {
1231 folly::init(&argc, &argv);
1232 if (ShouldTraceRdmaGet()) {
1233 std::cerr << "component=rdma_get_trace side=server event=enabled interval="
1234 << RdmaGetTraceInterval() << std::endl;
1235 }
1236 xmh::Reporter::StartReportThread();
1237
1238 base::PMMmapRegisterCenter::GetConfig().backend =
1239 base::PMMmapRegisterCenter::BackendFromUseDram(FLAGS_use_dram);
1240 base::PMMmapRegisterCenter::GetConfig().numa_id = FLAGS_numa_id;
1241
1242 base::global_socket_id = FLAGS_numa_id;
1243 LOG(INFO) << "set NUMA ID = " << FLAGS_numa_id;
1244
1245 const std::string config_path =
1246 FLAGS_config_path.empty()
1247 ? base::ResolveRecStoreConfigPath().string()
1248 : FLAGS_config_path;
1249 std::ifstream config_file(config_path);
1250 if (!config_file.is_open()) {
1251 LOG(FATAL) << "Cannot open config file: " << config_path;
1252 }
1253
1254 nlohmann::json config;
1255 config_file >> config;
1256 if (config.contains("cache_ps") && config["cache_ps"].is_object() &&
1257 config["cache_ps"].contains("base_kv_config")) {
1258 NormalizeDramValuePath(&config["cache_ps"]["base_kv_config"]);
1259 }
1260 if (config.contains("distributed_client") &&
1261 config["distributed_client"].is_object() &&
1262 config["distributed_client"].contains("base_kv_config")) {
1263 NormalizeDramValuePath(&config["distributed_client"]["base_kv_config"]);
1264 }
1265 std::unique_ptr<petps::RdmaControlPlaneServer> control_plane_server;
1266 if (FLAGS_global_id == 0) {
1267 control_plane_server = std::make_unique<petps::RdmaControlPlaneServer>(
1268 petps::RdmaControlPlaneEndpoint{
1269 FLAGS_rdma_control_plane_host,
1270 FLAGS_rdma_control_plane_port,
1271 FLAGS_rdma_control_plane_timeout_ms,
1272 });
1273 control_plane_server->Start();
1274 LOG(INFO) << "component=rdma_control_plane event=listening"
1275 << " server_id=0"
1276 << " host=" << FLAGS_rdma_control_plane_host
1277 << " port=" << FLAGS_rdma_control_plane_port;
1278 }
1279 auto cache_ps = std::make_unique<CachePS>(config["cache_ps"]);
1280 const int shard_id = ResolveShardId(config);
1281 auto ps = std::make_unique<PetPSServer>(
1282 cache_ps.get(), FLAGS_thread_num, shard_id, NamespaceToken());
1283 ps->Run();
1284 while (true) {
1285 std::this_thread::sleep_for(std::chrono::seconds(1));
1286 }
1287 return 0;
1288 }
1289