ps/rdma/petps_server.cc
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include <folly/init/Init.h> | ||
| 2 | |||
| 3 | #include <boost/coroutine2/all.hpp> | ||
| 4 | |||
| 5 | #include <atomic> | ||
| 6 | #include <array> | ||
| 7 | #include <algorithm> | ||
| 8 | #include <chrono> | ||
| 9 | #include <condition_variable> | ||
| 10 | #include <deque> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <cstdlib> | ||
| 13 | #include <cstring> | ||
| 14 | #include <fstream> | ||
| 15 | #include <iostream> | ||
| 16 | #include <limits> | ||
| 17 | #include <memory> | ||
| 18 | #include <mutex> | ||
| 19 | #include <stdexcept> | ||
| 20 | #include <string> | ||
| 21 | #include <thread> | ||
| 22 | #include <vector> | ||
| 23 | |||
| 24 | #include "base/bind_core.h" | ||
| 25 | #include "base/config.h" | ||
| 26 | #include "base/log.h" | ||
| 27 | #include "base/timer.h" | ||
| 28 | #include "memory/shm_file.h" | ||
| 29 | #include "ps/rdma/rdma_common.h" | ||
| 30 | #include "ps/base/cache_ps_impl.h" | ||
| 31 | #include "ps/rdma/control_plane.h" | ||
| 32 | #include "ps/rdma/rc_options.h" | ||
| 33 | #include "ps/rdma/rc_transport.h" | ||
| 34 | #include "ps/rdma/rdma_protocol.h" | ||
| 35 | #include "ps/rdma/rdma_status.h" | ||
| 36 | |||
| 37 | DEFINE_string(config_path, "", "config file path"); | ||
| 38 | DEFINE_int32(thread_num, 1, "RC write poll thread count"); | ||
| 39 | DECLARE_int32(global_id); | ||
| 40 | DECLARE_int32(num_server_processes); | ||
| 41 | DECLARE_int32(num_client_processes); | ||
| 42 | DEFINE_int32(value_size, 128, "embedding row bytes"); | ||
| 43 | DEFINE_int32(max_kv_num_per_request, 500, "max keys per request"); | ||
| 44 | DEFINE_bool(use_dram, false, "unused compatibility flag"); | ||
| 45 | DEFINE_int32(numa_id, 0, "NUMA node id for mmap and core binding"); | ||
| 46 | |||
| 47 | namespace { | ||
| 48 | |||
| 49 | using petps::Exchange; | ||
| 50 | using petps::NamespaceToken; | ||
| 51 | using petps::NowNs; | ||
| 52 | |||
| 53 | constexpr std::size_t kMaxDirectSgesPerWr = 32; | ||
| 54 | |||
| 55 | ✗ | bool ShouldTraceRdmaGet() { | |
| 56 | ✗ | static const bool enabled = [] { | |
| 57 | ✗ | const char* env = std::getenv("RECSTORE_RDMA_GET_TRACE"); | |
| 58 | ✗ | return env != nullptr && std::string(env) != "0"; | |
| 59 | ✗ | }(); | |
| 60 | ✗ | return enabled; | |
| 61 | } | ||
| 62 | |||
| 63 | ✗ | std::uint64_t RdmaGetTraceInterval() { | |
| 64 | ✗ | static const std::uint64_t interval = [] { | |
| 65 | ✗ | const char* env = std::getenv("RECSTORE_RDMA_GET_TRACE_INTERVAL"); | |
| 66 | ✗ | if (env == nullptr) { | |
| 67 | ✗ | return std::uint64_t{5000}; | |
| 68 | } | ||
| 69 | const auto parsed = | ||
| 70 | ✗ | static_cast<std::uint64_t>(std::strtoull(env, nullptr, 10)); | |
| 71 | ✗ | return parsed == 0 ? std::uint64_t{5000} : parsed; | |
| 72 | ✗ | }(); | |
| 73 | ✗ | return interval; | |
| 74 | } | ||
| 75 | |||
| 76 | ✗ | std::string TimestampNow() { | |
| 77 | ✗ | const auto now = std::chrono::system_clock::now().time_since_epoch(); | |
| 78 | return std::to_string( | ||
| 79 | ✗ | std::chrono::duration_cast<std::chrono::microseconds>(now).count()); | |
| 80 | } | ||
| 81 | |||
| 82 | ✗ | int ResolveShardId(const nlohmann::json& config) { | |
| 83 | ✗ | const int default_shard = FLAGS_global_id; | |
| 84 | ✗ | if (!config.contains("cache_ps") || !config["cache_ps"].is_object()) { | |
| 85 | ✗ | return default_shard; | |
| 86 | } | ||
| 87 | ✗ | const auto& cache_ps = config["cache_ps"]; | |
| 88 | ✗ | if (cache_ps.contains("servers") && cache_ps["servers"].is_array()) { | |
| 89 | ✗ | for (const auto& server : cache_ps["servers"]) { | |
| 90 | ✗ | if (server.value("shard", -1) == FLAGS_global_id) { | |
| 91 | ✗ | return server.value("shard", default_shard); | |
| 92 | } | ||
| 93 | } | ||
| 94 | } | ||
| 95 | ✗ | return default_shard; | |
| 96 | } | ||
| 97 | |||
| 98 | ✗ | void NormalizeDramValuePath(nlohmann::json* base_kv_config) { | |
| 99 | ✗ | if (base_kv_config == nullptr || !base_kv_config->is_object()) { | |
| 100 | ✗ | return; | |
| 101 | } | ||
| 102 | ✗ | if (!base_kv_config->contains("value") || | |
| 103 | ✗ | !(*base_kv_config)["value"].is_object()) { | |
| 104 | ✗ | return; | |
| 105 | } | ||
| 106 | ✗ | auto& value_cfg = (*base_kv_config)["value"]; | |
| 107 | const std::string value_type = | ||
| 108 | ✗ | value_cfg.value("type", std::string("DRAM_VALUE_STORE")); | |
| 109 | ✗ | if (value_type != "DRAM_VALUE_STORE") { | |
| 110 | ✗ | return; | |
| 111 | } | ||
| 112 | ✗ | const std::string path = value_cfg.value("path", std::string()); | |
| 113 | ✗ | if (path.empty() || path.rfind("/dev/shm", 0) == 0) { | |
| 114 | ✗ | return; | |
| 115 | } | ||
| 116 | ✗ | value_cfg["path"] = "/dev/shm/recstore_rdma_rc_" + TimestampNow() + "/value"; | |
| 117 | ✗ | } | |
| 118 | |||
| 119 | class PetPSServer { | ||
| 120 | public: | ||
| 121 | ✗ | PetPSServer(CachePS* cache_ps, | |
| 122 | int thread_count, | ||
| 123 | int shard_id, | ||
| 124 | const std::string& namespace_token) | ||
| 125 | ✗ | : cache_ps_(cache_ps), | |
| 126 | ✗ | thread_count_(thread_count), | |
| 127 | ✗ | shard_id_(shard_id), | |
| 128 | ✗ | control_plane_client_(petps::RdmaControlPlaneEndpoint{ | |
| 129 | FLAGS_rdma_control_plane_host, | ||
| 130 | FLAGS_rdma_control_plane_port, | ||
| 131 | FLAGS_rdma_control_plane_timeout_ms, | ||
| 132 | ✗ | }) { | |
| 133 | ✗ | petps::RcTransportConfig config; | |
| 134 | ✗ | config.shard_id = shard_id_; | |
| 135 | ✗ | config.num_clients = | |
| 136 | ✗ | FLAGS_rdma_rc_num_logical_clients >= 0 | |
| 137 | ✗ | ? FLAGS_rdma_rc_num_logical_clients | |
| 138 | : FLAGS_num_client_processes; | ||
| 139 | ✗ | config.qps_per_client_per_shard = FLAGS_rdma_rc_qps_per_client_per_shard; | |
| 140 | ✗ | config.slots_per_qp = FLAGS_rdma_rc_slots_per_qp; | |
| 141 | ✗ | config.request_slot_bytes = | |
| 142 | ✗ | static_cast<std::size_t>(FLAGS_rdma_rc_request_slot_bytes); | |
| 143 | ✗ | config.response_slot_bytes = | |
| 144 | ✗ | static_cast<std::size_t>(FLAGS_rdma_rc_response_slot_bytes); | |
| 145 | ✗ | config.control_plane_host = FLAGS_rdma_control_plane_host; | |
| 146 | ✗ | config.control_plane_port = FLAGS_rdma_control_plane_port; | |
| 147 | ✗ | config.control_plane_timeout_ms = FLAGS_rdma_control_plane_timeout_ms; | |
| 148 | ✗ | config.namespace_token = namespace_token; | |
| 149 | ✗ | transport_ = std::make_unique<petps::RcShardServerTransport>(config); | |
| 150 | ✗ | const auto backing = cache_ps_->GetRDMABackingRegion(); | |
| 151 | ✗ | if (backing.data != nullptr && backing.size > 0) { | |
| 152 | ✗ | transport_->RegisterLocalMemoryRegion(backing.data, backing.size); | |
| 153 | ✗ | LOG(INFO) << "component=rdma_rc_server event=value_region_registered" | |
| 154 | ✗ | << " bytes=" << backing.size; | |
| 155 | } else { | ||
| 156 | ✗ | LOG(INFO) << "component=rdma_rc_server event=value_region_unavailable"; | |
| 157 | } | ||
| 158 | ✗ | last_seq_.assign( | |
| 159 | ✗ | static_cast<std::size_t>(transport_->TotalSlots()), std::uint64_t{0}); | |
| 160 | ✗ | inflight_seq_.assign( | |
| 161 | ✗ | static_cast<std::size_t>(transport_->TotalSlots()), std::uint64_t{0}); | |
| 162 | ✗ | get_payload_worker_count_ = FLAGS_rdma_rc_server_get_workers; | |
| 163 | ✗ | if (get_payload_worker_count_ < 0) { | |
| 164 | ✗ | LOG(FATAL) << "--rdma_rc_server_get_workers must be non-negative"; | |
| 165 | } | ||
| 166 | ✗ | poller_profiles_.reserve( | |
| 167 | ✗ | static_cast<std::size_t>(std::max(1, thread_count_))); | |
| 168 | ✗ | for (int i = 0; i < std::max(1, thread_count_); ++i) { | |
| 169 | ✗ | poller_profiles_.emplace_back(std::make_unique<PollerProfile>()); | |
| 170 | } | ||
| 171 | ✗ | get_payload_completions_.resize( | |
| 172 | ✗ | static_cast<std::size_t>(std::max(1, thread_count_))); | |
| 173 | ✗ | } | |
| 174 | |||
| 175 | ✗ | void Run() { | |
| 176 | ✗ | StartGetPayloadWorkers(); | |
| 177 | ✗ | for (int i = 0; i < thread_count_; ++i) { | |
| 178 | ✗ | threads_.emplace_back(&PetPSServer::PollingThread, this, i); | |
| 179 | } | ||
| 180 | ✗ | } | |
| 181 | |||
| 182 | private: | ||
| 183 | struct GetPayloadTask { | ||
| 184 | int slot = -1; | ||
| 185 | int client_id = -1; | ||
| 186 | int qp_index = -1; | ||
| 187 | int slot_in_qp = -1; | ||
| 188 | int poll_thread_id = -1; | ||
| 189 | std::uint64_t seq = 0; | ||
| 190 | petps::RequestDescriptor descriptor{}; | ||
| 191 | const char* payload = nullptr; | ||
| 192 | petps::RcShardServerTransport::ResponseView response{}; | ||
| 193 | }; | ||
| 194 | |||
| 195 | struct GetPayloadCompletion { | ||
| 196 | int slot = -1; | ||
| 197 | int client_id = -1; | ||
| 198 | int qp_index = -1; | ||
| 199 | int slot_in_qp = -1; | ||
| 200 | int poll_thread_id = -1; | ||
| 201 | std::uint64_t seq = 0; | ||
| 202 | petps::RcShardServerTransport::ResponseView response{}; | ||
| 203 | bool payload_written_direct = false; | ||
| 204 | }; | ||
| 205 | |||
| 206 | struct ProfileCounters { | ||
| 207 | std::atomic<std::uint64_t> scan_rounds{0}; | ||
| 208 | std::atomic<std::uint64_t> scanned_slots{0}; | ||
| 209 | std::atomic<std::uint64_t> ready_slots{0}; | ||
| 210 | std::atomic<std::uint64_t> not_ready_slots{0}; | ||
| 211 | std::atomic<std::uint64_t> zero_seq_ready{0}; | ||
| 212 | std::atomic<std::uint64_t> duplicate_seq_ready{0}; | ||
| 213 | std::atomic<std::uint64_t> inflight_seq_ready{0}; | ||
| 214 | std::atomic<std::uint64_t> empty_scan_rounds{0}; | ||
| 215 | std::atomic<std::uint64_t> max_ready_per_round{0}; | ||
| 216 | std::atomic<std::uint64_t> handled_get{0}; | ||
| 217 | std::atomic<std::uint64_t> handled_put{0}; | ||
| 218 | std::atomic<std::uint64_t> handled_update{0}; | ||
| 219 | std::atomic<std::uint64_t> handled_init{0}; | ||
| 220 | std::atomic<std::uint64_t> invalid_descriptor{0}; | ||
| 221 | std::atomic<std::uint64_t> wrong_shard{0}; | ||
| 222 | std::atomic<std::uint64_t> handle_get_ns{0}; | ||
| 223 | std::atomic<std::uint64_t> get_batch_get_ns{0}; | ||
| 224 | std::atomic<std::uint64_t> get_index_lookup_ns{0}; | ||
| 225 | std::atomic<std::uint64_t> get_zero_fill_ns{0}; | ||
| 226 | std::atomic<std::uint64_t> get_row_copy_ns{0}; | ||
| 227 | std::atomic<std::uint64_t> get_rows{0}; | ||
| 228 | std::atomic<std::uint64_t> get_value_bytes{0}; | ||
| 229 | std::atomic<std::uint64_t> get_missing_rows{0}; | ||
| 230 | std::atomic<std::uint64_t> get_direct_sg{0}; | ||
| 231 | std::atomic<std::uint64_t> get_direct_sg_fallback{0}; | ||
| 232 | std::atomic<std::uint64_t> get_direct_sg_ns{0}; | ||
| 233 | std::atomic<std::uint64_t> get_direct_sg_wr{0}; | ||
| 234 | std::atomic<std::uint64_t> handle_put_ns{0}; | ||
| 235 | std::atomic<std::uint64_t> handle_update_ns{0}; | ||
| 236 | std::atomic<std::uint64_t> handle_init_ns{0}; | ||
| 237 | std::atomic<std::uint64_t> complete_response_ns{0}; | ||
| 238 | std::atomic<std::uint64_t> poll_loop_ns{0}; | ||
| 239 | std::atomic<std::uint64_t> next_report_ns{0}; | ||
| 240 | }; | ||
| 241 | |||
| 242 | struct PollerProfile { | ||
| 243 | std::atomic<std::uint64_t> scan_rounds{0}; | ||
| 244 | std::atomic<std::uint64_t> scanned_slots{0}; | ||
| 245 | std::atomic<std::uint64_t> ready_slots{0}; | ||
| 246 | std::atomic<std::uint64_t> not_ready_slots{0}; | ||
| 247 | std::atomic<std::uint64_t> duplicate_seq_ready{0}; | ||
| 248 | std::atomic<std::uint64_t> inflight_seq_ready{0}; | ||
| 249 | std::atomic<std::uint64_t> handled_get{0}; | ||
| 250 | std::atomic<std::uint64_t> poll_loop_ns{0}; | ||
| 251 | }; | ||
| 252 | |||
| 253 | static void | ||
| 254 | ✗ | UpdateMax(std::atomic<std::uint64_t>* value, std::uint64_t candidate) { | |
| 255 | ✗ | std::uint64_t current = value->load(std::memory_order_relaxed); | |
| 256 | ✗ | while (candidate > current && | |
| 257 | ✗ | !value->compare_exchange_weak( | |
| 258 | current, candidate, std::memory_order_relaxed)) { | ||
| 259 | } | ||
| 260 | ✗ | } | |
| 261 | |||
| 262 | ✗ | void MaybeReportProfile(int thread_id) { | |
| 263 | ✗ | if (FLAGS_rdma_rc_profile_interval_ms <= 0 || thread_id != 0) { | |
| 264 | ✗ | return; | |
| 265 | } | ||
| 266 | ✗ | const std::uint64_t now = NowNs(); | |
| 267 | ✗ | const std::uint64_t interval = | |
| 268 | ✗ | static_cast<std::uint64_t>(FLAGS_rdma_rc_profile_interval_ms) * 1000000; | |
| 269 | std::uint64_t expected = | ||
| 270 | ✗ | profile_.next_report_ns.load(std::memory_order_relaxed); | |
| 271 | ✗ | if (expected == 0) { | |
| 272 | ✗ | profile_.next_report_ns.compare_exchange_strong( | |
| 273 | expected, now + interval, std::memory_order_relaxed); | ||
| 274 | ✗ | return; | |
| 275 | } | ||
| 276 | ✗ | if (now < expected || | |
| 277 | ✗ | !profile_.next_report_ns.compare_exchange_strong( | |
| 278 | expected, now + interval, std::memory_order_relaxed)) { | ||
| 279 | ✗ | return; | |
| 280 | } | ||
| 281 | |||
| 282 | ✗ | const std::uint64_t scan_rounds = Exchange(&profile_.scan_rounds); | |
| 283 | ✗ | const std::uint64_t scanned_slots = Exchange(&profile_.scanned_slots); | |
| 284 | ✗ | const std::uint64_t ready_slots = Exchange(&profile_.ready_slots); | |
| 285 | ✗ | const std::uint64_t not_ready_slots = Exchange(&profile_.not_ready_slots); | |
| 286 | ✗ | const std::uint64_t zero_seq_ready = Exchange(&profile_.zero_seq_ready); | |
| 287 | const std::uint64_t duplicate_seq_ready = | ||
| 288 | ✗ | Exchange(&profile_.duplicate_seq_ready); | |
| 289 | const std::uint64_t inflight_seq_ready = | ||
| 290 | ✗ | Exchange(&profile_.inflight_seq_ready); | |
| 291 | const std::uint64_t empty_scan_rounds = | ||
| 292 | ✗ | Exchange(&profile_.empty_scan_rounds); | |
| 293 | const std::uint64_t max_ready_per_round = | ||
| 294 | ✗ | Exchange(&profile_.max_ready_per_round); | |
| 295 | ✗ | const std::uint64_t handled_get = Exchange(&profile_.handled_get); | |
| 296 | ✗ | const std::uint64_t handled_put = Exchange(&profile_.handled_put); | |
| 297 | ✗ | const std::uint64_t handled_update = Exchange(&profile_.handled_update); | |
| 298 | ✗ | const std::uint64_t handled_init = Exchange(&profile_.handled_init); | |
| 299 | ✗ | const std::uint64_t complete_count = | |
| 300 | ✗ | handled_get + handled_put + handled_update + handled_init; | |
| 301 | ✗ | const std::uint64_t handle_get_ns = Exchange(&profile_.handle_get_ns); | |
| 302 | ✗ | const std::uint64_t get_batch_get_ns = Exchange(&profile_.get_batch_get_ns); | |
| 303 | const std::uint64_t get_index_lookup_ns = | ||
| 304 | ✗ | Exchange(&profile_.get_index_lookup_ns); | |
| 305 | ✗ | const std::uint64_t get_zero_fill_ns = Exchange(&profile_.get_zero_fill_ns); | |
| 306 | ✗ | const std::uint64_t get_row_copy_ns = Exchange(&profile_.get_row_copy_ns); | |
| 307 | ✗ | const std::uint64_t get_rows = Exchange(&profile_.get_rows); | |
| 308 | ✗ | const std::uint64_t get_value_bytes = Exchange(&profile_.get_value_bytes); | |
| 309 | ✗ | const std::uint64_t get_missing_rows = Exchange(&profile_.get_missing_rows); | |
| 310 | ✗ | const std::uint64_t get_direct_sg = Exchange(&profile_.get_direct_sg); | |
| 311 | ✗ | const std::uint64_t get_direct_sg_ns = Exchange(&profile_.get_direct_sg_ns); | |
| 312 | ✗ | const std::uint64_t handle_put_ns = Exchange(&profile_.handle_put_ns); | |
| 313 | ✗ | const std::uint64_t handle_update_ns = Exchange(&profile_.handle_update_ns); | |
| 314 | ✗ | const std::uint64_t handle_init_ns = Exchange(&profile_.handle_init_ns); | |
| 315 | const std::uint64_t complete_response_ns = | ||
| 316 | ✗ | Exchange(&profile_.complete_response_ns); | |
| 317 | ✗ | const std::uint64_t poll_loop_ns = Exchange(&profile_.poll_loop_ns); | |
| 318 | ✗ | std::uint64_t poller_min_get = std::numeric_limits<std::uint64_t>::max(); | |
| 319 | ✗ | std::uint64_t poller_max_get = 0; | |
| 320 | ✗ | int poller_min_get_thread = -1; | |
| 321 | ✗ | int poller_max_get_thread = -1; | |
| 322 | ✗ | std::uint64_t poller_total_get = 0; | |
| 323 | ✗ | std::uint64_t poller_active = 0; | |
| 324 | ✗ | for (std::size_t i = 0; i < poller_profiles_.size(); ++i) { | |
| 325 | ✗ | auto& poller = *poller_profiles_[i]; | |
| 326 | ✗ | const std::uint64_t poller_get = Exchange(&poller.handled_get); | |
| 327 | ✗ | const std::uint64_t poller_scan_rounds = Exchange(&poller.scan_rounds); | |
| 328 | const std::uint64_t poller_scanned_slots = | ||
| 329 | ✗ | Exchange(&poller.scanned_slots); | |
| 330 | ✗ | const std::uint64_t poller_ready_slots = Exchange(&poller.ready_slots); | |
| 331 | const std::uint64_t poller_not_ready_slots = | ||
| 332 | ✗ | Exchange(&poller.not_ready_slots); | |
| 333 | const std::uint64_t poller_duplicate_seq_ready = | ||
| 334 | ✗ | Exchange(&poller.duplicate_seq_ready); | |
| 335 | const std::uint64_t poller_inflight_seq_ready = | ||
| 336 | ✗ | Exchange(&poller.inflight_seq_ready); | |
| 337 | ✗ | const std::uint64_t poller_poll_loop_ns = Exchange(&poller.poll_loop_ns); | |
| 338 | ✗ | if (poller_get > 0) { | |
| 339 | ✗ | ++poller_active; | |
| 340 | } | ||
| 341 | ✗ | poller_total_get += poller_get; | |
| 342 | ✗ | if (poller_get < poller_min_get) { | |
| 343 | ✗ | poller_min_get = poller_get; | |
| 344 | ✗ | poller_min_get_thread = static_cast<int>(i); | |
| 345 | } | ||
| 346 | ✗ | if (poller_get > poller_max_get) { | |
| 347 | ✗ | poller_max_get = poller_get; | |
| 348 | ✗ | poller_max_get_thread = static_cast<int>(i); | |
| 349 | } | ||
| 350 | std::cout | ||
| 351 | << "component=rdma_rc_server_poller_profile" | ||
| 352 | ✗ | << " shard=" << shard_id_ << " thread_id=" << i << " scan_rounds=" | |
| 353 | ✗ | << poller_scan_rounds << " scanned_slots=" << poller_scanned_slots | |
| 354 | ✗ | << " ready_slots=" << poller_ready_slots << " scan_hit_pct=" | |
| 355 | << (poller_scanned_slots == 0 | ||
| 356 | ? 0.0 | ||
| 357 | ✗ | : 100.0 * static_cast<double>(poller_ready_slots) / | |
| 358 | ✗ | static_cast<double>(poller_scanned_slots)) | |
| 359 | ✗ | << " not_ready_slots=" << poller_not_ready_slots | |
| 360 | ✗ | << " duplicate_seq_ready=" << poller_duplicate_seq_ready | |
| 361 | ✗ | << " inflight_seq_ready=" << poller_inflight_seq_ready | |
| 362 | ✗ | << " handled_get=" << poller_get << " poll_loop_avg_ns=" | |
| 363 | << (poller_scan_rounds == 0 | ||
| 364 | ? 0 | ||
| 365 | ✗ | : poller_poll_loop_ns / poller_scan_rounds) | |
| 366 | ✗ | << std::endl; | |
| 367 | } | ||
| 368 | ✗ | if (poller_min_get == std::numeric_limits<std::uint64_t>::max()) { | |
| 369 | ✗ | poller_min_get = 0; | |
| 370 | } | ||
| 371 | std::cout | ||
| 372 | << "component=rdma_rc_server_profile" | ||
| 373 | ✗ | << " shard=" << shard_id_ << " threads=" << thread_count_ | |
| 374 | ✗ | << " scan_rounds=" << scan_rounds << " scanned_slots=" << scanned_slots | |
| 375 | ✗ | << " ready_slots=" << ready_slots << " not_ready_slots=" | |
| 376 | ✗ | << not_ready_slots << " zero_seq_ready=" << zero_seq_ready | |
| 377 | ✗ | << " duplicate_seq_ready=" << duplicate_seq_ready | |
| 378 | ✗ | << " inflight_seq_ready=" << inflight_seq_ready | |
| 379 | ✗ | << " empty_scan_rounds=" << empty_scan_rounds << " scan_hit_pct=" | |
| 380 | << (scanned_slots == 0 ? 0.0 | ||
| 381 | ✗ | : 100.0 * static_cast<double>(ready_slots) / | |
| 382 | ✗ | static_cast<double>(scanned_slots)) | |
| 383 | ✗ | << " ready_round_pct=" | |
| 384 | << (scan_rounds == 0 | ||
| 385 | ? 0.0 | ||
| 386 | ✗ | : 100.0 * static_cast<double>(scan_rounds - empty_scan_rounds) / | |
| 387 | ✗ | static_cast<double>(scan_rounds)) | |
| 388 | ✗ | << " avg_ready_per_round=" | |
| 389 | << (scan_rounds == 0 ? 0.0 | ||
| 390 | ✗ | : static_cast<double>(ready_slots) / | |
| 391 | ✗ | static_cast<double>(scan_rounds)) | |
| 392 | ✗ | << " max_ready_per_round=" << max_ready_per_round | |
| 393 | ✗ | << " handled_get=" << handled_get << " handled_put=" << handled_put | |
| 394 | ✗ | << " handled_update=" << handled_update | |
| 395 | ✗ | << " handled_init=" << handled_init | |
| 396 | ✗ | << " invalid_descriptor=" << Exchange(&profile_.invalid_descriptor) | |
| 397 | ✗ | << " wrong_shard=" << Exchange(&profile_.wrong_shard) | |
| 398 | ✗ | << " handle_get_avg_ns=" | |
| 399 | ✗ | << (handled_get == 0 ? 0 : handle_get_ns / handled_get) | |
| 400 | ✗ | << " get_batch_get_avg_ns=" | |
| 401 | ✗ | << (handled_get == 0 ? 0 : get_batch_get_ns / handled_get) | |
| 402 | ✗ | << " get_index_lookup_avg_ns=" | |
| 403 | ✗ | << (handled_get == 0 ? 0 : get_index_lookup_ns / handled_get) | |
| 404 | ✗ | << " get_zero_fill_avg_ns=" | |
| 405 | ✗ | << (handled_get == 0 ? 0 : get_zero_fill_ns / handled_get) | |
| 406 | ✗ | << " get_row_copy_avg_ns=" | |
| 407 | ✗ | << (handled_get == 0 ? 0 : get_row_copy_ns / handled_get) | |
| 408 | ✗ | << " get_rows=" << get_rows << " get_value_bytes=" << get_value_bytes | |
| 409 | ✗ | << " get_missing_rows=" << get_missing_rows | |
| 410 | ✗ | << " get_direct_sg=" << get_direct_sg << " get_direct_sg_fallback=" | |
| 411 | ✗ | << Exchange(&profile_.get_direct_sg_fallback) | |
| 412 | ✗ | << " get_direct_sg_avg_ns=" | |
| 413 | ✗ | << (get_direct_sg == 0 ? 0 : get_direct_sg_ns / get_direct_sg) | |
| 414 | ✗ | << " get_direct_sg_wr=" << Exchange(&profile_.get_direct_sg_wr) | |
| 415 | ✗ | << " handle_put_avg_ns=" | |
| 416 | ✗ | << (handled_put == 0 ? 0 : handle_put_ns / handled_put) | |
| 417 | ✗ | << " handle_update_avg_ns=" | |
| 418 | ✗ | << (handled_update == 0 ? 0 : handle_update_ns / handled_update) | |
| 419 | ✗ | << " handle_init_avg_ns=" | |
| 420 | ✗ | << (handled_init == 0 ? 0 : handle_init_ns / handled_init) | |
| 421 | ✗ | << " complete_response_avg_ns=" | |
| 422 | ✗ | << (complete_count == 0 ? 0 : complete_response_ns / complete_count) | |
| 423 | ✗ | << " poll_loop_avg_ns=" | |
| 424 | ✗ | << (scan_rounds == 0 ? 0 : poll_loop_ns / scan_rounds) | |
| 425 | ✗ | << " poller_active=" << poller_active << " poller_total_get=" | |
| 426 | ✗ | << poller_total_get << " poller_min_get=" << poller_min_get | |
| 427 | ✗ | << " poller_min_get_thread=" << poller_min_get_thread | |
| 428 | ✗ | << " poller_max_get=" << poller_max_get | |
| 429 | ✗ | << " poller_max_get_thread=" << poller_max_get_thread << std::endl; | |
| 430 | } | ||
| 431 | |||
| 432 | ✗ | bool GetPayloadOffloadEnabled() const { | |
| 433 | ✗ | return get_payload_worker_count_ > 0; | |
| 434 | } | ||
| 435 | |||
| 436 | ✗ | std::size_t MaxGetPayloadQueueDepth() const { | |
| 437 | ✗ | return static_cast<std::size_t>(std::max(1, transport_->TotalSlots())); | |
| 438 | } | ||
| 439 | |||
| 440 | ✗ | void StartGetPayloadWorkers() { | |
| 441 | ✗ | if (!GetPayloadOffloadEnabled()) { | |
| 442 | ✗ | return; | |
| 443 | } | ||
| 444 | ✗ | for (int worker_id = 0; worker_id < get_payload_worker_count_; | |
| 445 | ++worker_id) { | ||
| 446 | ✗ | get_payload_workers_.emplace_back( | |
| 447 | ✗ | &PetPSServer::GetPayloadWorkerLoop, this, worker_id); | |
| 448 | } | ||
| 449 | ✗ | LOG(INFO) << "component=rdma_rc_server event=get_payload_workers_started" | |
| 450 | ✗ | << " count=" << get_payload_worker_count_; | |
| 451 | } | ||
| 452 | |||
| 453 | ✗ | void BindServerCore(int core_index) { | |
| 454 | ✗ | base::bind_core_with_env_offset(core_index); | |
| 455 | ✗ | } | |
| 456 | |||
| 457 | ✗ | bool EnqueueGetPayloadTask(const GetPayloadTask& task) { | |
| 458 | ✗ | std::lock_guard<std::mutex> guard(get_payload_mu_); | |
| 459 | ✗ | if (get_payload_tasks_.size() >= MaxGetPayloadQueueDepth()) { | |
| 460 | ✗ | return false; | |
| 461 | } | ||
| 462 | ✗ | get_payload_tasks_.push_back(task); | |
| 463 | ✗ | get_payload_cv_.notify_one(); | |
| 464 | ✗ | return true; | |
| 465 | ✗ | } | |
| 466 | |||
| 467 | ✗ | std::size_t PollThreadIndex(int poll_thread_id) const { | |
| 468 | ✗ | return static_cast<std::size_t>(poll_thread_id); | |
| 469 | } | ||
| 470 | |||
| 471 | ✗ | bool TryPopGetPayloadCompletion(int poll_thread_id, | |
| 472 | GetPayloadCompletion* completion) { | ||
| 473 | ✗ | std::lock_guard<std::mutex> guard(get_payload_mu_); | |
| 474 | auto& completions = | ||
| 475 | ✗ | get_payload_completions_.at(PollThreadIndex(poll_thread_id)); | |
| 476 | ✗ | if (completions.empty()) { | |
| 477 | ✗ | return false; | |
| 478 | } | ||
| 479 | ✗ | *completion = completions.front(); | |
| 480 | ✗ | completions.pop_front(); | |
| 481 | ✗ | return true; | |
| 482 | ✗ | } | |
| 483 | |||
| 484 | ✗ | void PushGetPayloadCompletion(const GetPayloadCompletion& completion) { | |
| 485 | ✗ | std::lock_guard<std::mutex> guard(get_payload_mu_); | |
| 486 | ✗ | get_payload_completions_.at(PollThreadIndex(completion.poll_thread_id)) | |
| 487 | ✗ | .push_back(completion); | |
| 488 | ✗ | } | |
| 489 | |||
| 490 | ✗ | void AccumulateFlatGetProfile(const CachePS::FlatGetProfile& get_profile) { | |
| 491 | ✗ | profile_.get_batch_get_ns.fetch_add( | |
| 492 | ✗ | get_profile.batch_get_ns, std::memory_order_relaxed); | |
| 493 | ✗ | profile_.get_index_lookup_ns.fetch_add( | |
| 494 | ✗ | get_profile.index_lookup_ns, std::memory_order_relaxed); | |
| 495 | ✗ | profile_.get_zero_fill_ns.fetch_add( | |
| 496 | ✗ | get_profile.zero_fill_ns, std::memory_order_relaxed); | |
| 497 | ✗ | profile_.get_row_copy_ns.fetch_add( | |
| 498 | ✗ | get_profile.row_copy_ns, std::memory_order_relaxed); | |
| 499 | ✗ | profile_.get_rows.fetch_add(get_profile.rows, std::memory_order_relaxed); | |
| 500 | ✗ | profile_.get_value_bytes.fetch_add( | |
| 501 | ✗ | get_profile.value_bytes, std::memory_order_relaxed); | |
| 502 | ✗ | profile_.get_missing_rows.fetch_add( | |
| 503 | ✗ | get_profile.missing_rows, std::memory_order_relaxed); | |
| 504 | ✗ | } | |
| 505 | |||
| 506 | ✗ | void GetPayloadWorkerLoop(int worker_id) { | |
| 507 | ✗ | BindServerCore(thread_count_ + worker_id); | |
| 508 | ✗ | LOG(INFO) << "component=rdma_rc_server event=get_payload_worker_ready" | |
| 509 | ✗ | << " worker_id=" << worker_id; | |
| 510 | while (true) { | ||
| 511 | ✗ | GetPayloadTask task; | |
| 512 | { | ||
| 513 | ✗ | std::unique_lock<std::mutex> lock(get_payload_mu_); | |
| 514 | ✗ | get_payload_cv_.wait(lock, [this] { | |
| 515 | ✗ | return !get_payload_tasks_.empty(); | |
| 516 | }); | ||
| 517 | ✗ | task = get_payload_tasks_.front(); | |
| 518 | ✗ | get_payload_tasks_.pop_front(); | |
| 519 | ✗ | } | |
| 520 | |||
| 521 | ✗ | const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0; | |
| 522 | ✗ | const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0; | |
| 523 | ✗ | const bool payload_written_direct = HandleGet( | |
| 524 | task.descriptor, | ||
| 525 | task.payload, | ||
| 526 | &task.response, | ||
| 527 | worker_id, | ||
| 528 | task.slot_in_qp); | ||
| 529 | ✗ | if (profile_enabled) { | |
| 530 | ✗ | profile_.handled_get.fetch_add(1, std::memory_order_relaxed); | |
| 531 | ✗ | profile_.handle_get_ns.fetch_add( | |
| 532 | ✗ | NowNs() - handle_start_ns, std::memory_order_relaxed); | |
| 533 | } | ||
| 534 | const GetPayloadCompletion completion{ | ||
| 535 | ✗ | task.slot, | |
| 536 | ✗ | task.client_id, | |
| 537 | ✗ | task.qp_index, | |
| 538 | ✗ | task.slot_in_qp, | |
| 539 | ✗ | task.poll_thread_id, | |
| 540 | ✗ | task.seq, | |
| 541 | task.response, | ||
| 542 | payload_written_direct, | ||
| 543 | ✗ | }; | |
| 544 | ✗ | PushGetPayloadCompletion(completion); | |
| 545 | ✗ | } | |
| 546 | } | ||
| 547 | |||
| 548 | ✗ | void CompleteResponseForSlot( | |
| 549 | int slot, | ||
| 550 | int client_id, | ||
| 551 | int qp_index, | ||
| 552 | int slot_in_qp, | ||
| 553 | const petps::RcShardServerTransport::ResponseView& response, | ||
| 554 | std::uint64_t seq, | ||
| 555 | bool profile_enabled) { | ||
| 556 | std::atomic_thread_fence(std::memory_order_release); | ||
| 557 | ✗ | const std::uint64_t complete_start_ns = profile_enabled ? NowNs() : 0; | |
| 558 | ✗ | transport_->CompleteResponse( | |
| 559 | client_id, qp_index, slot_in_qp, response, seq); | ||
| 560 | ✗ | if (profile_enabled) { | |
| 561 | ✗ | profile_.complete_response_ns.fetch_add( | |
| 562 | ✗ | NowNs() - complete_start_ns, std::memory_order_relaxed); | |
| 563 | } | ||
| 564 | ✗ | VLOG(1) << "component=rdma_rc_server event=complete shard=" << shard_id_ | |
| 565 | ✗ | << " slot=" << slot << " client_id=" << client_id | |
| 566 | ✗ | << " qp=" << qp_index << " seq=" << seq | |
| 567 | ✗ | << " status=" << response.status->status | |
| 568 | ✗ | << " response_bytes=" << response.status->response_bytes; | |
| 569 | ✗ | last_seq_[static_cast<std::size_t>(slot)] = seq; | |
| 570 | ✗ | if (GetPayloadOffloadEnabled()) { | |
| 571 | ✗ | inflight_seq_[static_cast<std::size_t>(slot)] = 0; | |
| 572 | } | ||
| 573 | ✗ | } | |
| 574 | |||
| 575 | ✗ | void CompleteResponseStatusOnlyForSlot( | |
| 576 | int slot, | ||
| 577 | int client_id, | ||
| 578 | int qp_index, | ||
| 579 | int slot_in_qp, | ||
| 580 | const petps::RcShardServerTransport::ResponseView& response, | ||
| 581 | std::uint64_t seq, | ||
| 582 | bool profile_enabled) { | ||
| 583 | std::atomic_thread_fence(std::memory_order_release); | ||
| 584 | ✗ | const std::uint64_t complete_start_ns = profile_enabled ? NowNs() : 0; | |
| 585 | ✗ | transport_->CompleteResponseStatusOnly( | |
| 586 | client_id, qp_index, slot_in_qp, response, seq); | ||
| 587 | ✗ | if (profile_enabled) { | |
| 588 | ✗ | profile_.complete_response_ns.fetch_add( | |
| 589 | ✗ | NowNs() - complete_start_ns, std::memory_order_relaxed); | |
| 590 | } | ||
| 591 | ✗ | VLOG(1) << "component=rdma_rc_server event=complete_direct shard=" | |
| 592 | ✗ | << shard_id_ << " slot=" << slot << " client_id=" << client_id | |
| 593 | ✗ | << " qp=" << qp_index << " seq=" << seq | |
| 594 | ✗ | << " status=" << response.status->status | |
| 595 | ✗ | << " response_bytes=" << response.status->response_bytes; | |
| 596 | ✗ | last_seq_[static_cast<std::size_t>(slot)] = seq; | |
| 597 | ✗ | if (GetPayloadOffloadEnabled()) { | |
| 598 | ✗ | inflight_seq_[static_cast<std::size_t>(slot)] = 0; | |
| 599 | } | ||
| 600 | ✗ | } | |
| 601 | |||
| 602 | ✗ | void DrainGetPayloadCompletions(int poll_thread_id, bool profile_enabled) { | |
| 603 | ✗ | GetPayloadCompletion completion; | |
| 604 | ✗ | while (TryPopGetPayloadCompletion(poll_thread_id, &completion)) { | |
| 605 | ✗ | if (completion.payload_written_direct) { | |
| 606 | ✗ | CompleteResponseStatusOnlyForSlot( | |
| 607 | completion.slot, | ||
| 608 | completion.client_id, | ||
| 609 | completion.qp_index, | ||
| 610 | completion.slot_in_qp, | ||
| 611 | completion.response, | ||
| 612 | completion.seq, | ||
| 613 | profile_enabled); | ||
| 614 | } else { | ||
| 615 | ✗ | CompleteResponseForSlot( | |
| 616 | completion.slot, | ||
| 617 | completion.client_id, | ||
| 618 | completion.qp_index, | ||
| 619 | completion.slot_in_qp, | ||
| 620 | completion.response, | ||
| 621 | completion.seq, | ||
| 622 | profile_enabled); | ||
| 623 | } | ||
| 624 | } | ||
| 625 | ✗ | } | |
| 626 | |||
| 627 | ✗ | bool HandleGetDirectSg( | |
| 628 | const petps::RequestDescriptor& descriptor, | ||
| 629 | base::ConstArray<std::uint64_t> keys, | ||
| 630 | petps::RcShardServerTransport::ResponseView* response, | ||
| 631 | int thread_id, | ||
| 632 | int slot_in_qp, | ||
| 633 | CachePS::FlatGetProfile* get_profile) { | ||
| 634 | ✗ | if (descriptor.response_bytes == 0 || descriptor.embedding_dim == 0) { | |
| 635 | ✗ | return false; | |
| 636 | } | ||
| 637 | ✗ | const std::size_t row_bytes = | |
| 638 | ✗ | static_cast<std::size_t>(descriptor.embedding_dim) * sizeof(float); | |
| 639 | ✗ | if (row_bytes == 0 || | |
| 640 | ✗ | descriptor.response_bytes != | |
| 641 | ✗ | descriptor.key_count * static_cast<std::uint32_t>(row_bytes)) { | |
| 642 | ✗ | return false; | |
| 643 | } | ||
| 644 | |||
| 645 | ✗ | thread_local std::vector<CachePS::DirectFixedRow> rows; | |
| 646 | ✗ | rows.clear(); | |
| 647 | const std::uint64_t direct_start_ns = | ||
| 648 | ✗ | FLAGS_rdma_rc_profile_interval_ms > 0 ? NowNs() : 0; | |
| 649 | ✗ | const bool ok = cache_ps_->GetParameterDirectFixedRows( | |
| 650 | keys, | ||
| 651 | ✗ | descriptor.key_count, | |
| 652 | ✗ | descriptor.embedding_dim, | |
| 653 | thread_id, | ||
| 654 | &rows, | ||
| 655 | get_profile); | ||
| 656 | ✗ | if (!ok || rows.size() != descriptor.key_count) { | |
| 657 | ✗ | return false; | |
| 658 | } | ||
| 659 | ✗ | std::uint64_t response_offset = 0; | |
| 660 | ✗ | std::uint64_t wr_count = 0; | |
| 661 | ✗ | for (std::size_t row = 0; row < rows.size();) { | |
| 662 | ✗ | std::array<petps::RawVerbsSge, kMaxDirectSgesPerWr> sges{}; | |
| 663 | ✗ | std::size_t sge_count = 0; | |
| 664 | ✗ | std::size_t row_count = 0; | |
| 665 | ✗ | for (; row < rows.size(); ++row) { | |
| 666 | ✗ | const auto& ref = rows[row]; | |
| 667 | ✗ | if (ref.missing || ref.data == nullptr || ref.size != row_bytes) { | |
| 668 | ✗ | return false; | |
| 669 | } | ||
| 670 | ✗ | if (sge_count > 0) { | |
| 671 | ✗ | auto& last = sges[sge_count - 1]; | |
| 672 | ✗ | const char* last_end = | |
| 673 | ✗ | static_cast<const char*>(last.data) + last.bytes; | |
| 674 | ✗ | if (last_end == ref.data) { | |
| 675 | ✗ | last.bytes += row_bytes; | |
| 676 | ✗ | ++row_count; | |
| 677 | ✗ | continue; | |
| 678 | } | ||
| 679 | } | ||
| 680 | ✗ | if (sge_count == kMaxDirectSgesPerWr) { | |
| 681 | ✗ | break; | |
| 682 | } | ||
| 683 | ✗ | sges[sge_count++] = petps::RawVerbsSge{ref.data, row_bytes}; | |
| 684 | ✗ | ++row_count; | |
| 685 | } | ||
| 686 | ✗ | const std::uint64_t bytes = | |
| 687 | static_cast<std::uint64_t>(row_count * row_bytes); | ||
| 688 | ✗ | transport_->WriteResponsePayloadSg( | |
| 689 | ✗ | descriptor.client_id, | |
| 690 | ✗ | descriptor.qp_index, | |
| 691 | slot_in_qp, | ||
| 692 | base::ConstArray<petps::RawVerbsSge>( | ||
| 693 | ✗ | sges.data(), static_cast<int>(sge_count)), | |
| 694 | response_offset, | ||
| 695 | bytes); | ||
| 696 | ✗ | response_offset += bytes; | |
| 697 | ✗ | ++wr_count; | |
| 698 | } | ||
| 699 | ✗ | response->status->status = static_cast<std::int32_t>(petps::RpcStatus::kOk); | |
| 700 | ✗ | response->status->response_bytes = | |
| 701 | ✗ | static_cast<std::uint32_t>(descriptor.response_bytes); | |
| 702 | ✗ | if (FLAGS_rdma_rc_profile_interval_ms > 0) { | |
| 703 | ✗ | profile_.get_direct_sg.fetch_add(1, std::memory_order_relaxed); | |
| 704 | ✗ | profile_.get_direct_sg_ns.fetch_add( | |
| 705 | ✗ | NowNs() - direct_start_ns, std::memory_order_relaxed); | |
| 706 | ✗ | profile_.get_direct_sg_wr.fetch_add(wr_count, std::memory_order_relaxed); | |
| 707 | ✗ | if (get_profile != nullptr) { | |
| 708 | ✗ | AccumulateFlatGetProfile(*get_profile); | |
| 709 | } | ||
| 710 | } | ||
| 711 | ✗ | return true; | |
| 712 | } | ||
| 713 | |||
| 714 | ✗ | bool HandleGet(const petps::RequestDescriptor& descriptor, | |
| 715 | const char* payload, | ||
| 716 | petps::RcShardServerTransport::ResponseView* response, | ||
| 717 | int thread_id, | ||
| 718 | int slot_in_qp) { | ||
| 719 | ✗ | if (FLAGS_rdma_rc_fake_get_mode == "status_only") { | |
| 720 | ✗ | response->status->status = | |
| 721 | static_cast<std::int32_t>(petps::RpcStatus::kOk); | ||
| 722 | ✗ | response->status->response_bytes = 0; | |
| 723 | ✗ | return false; | |
| 724 | } | ||
| 725 | ✗ | if (FLAGS_rdma_rc_fake_get_mode == "payload_memset") { | |
| 726 | ✗ | std::memset(response->payload, 0, descriptor.response_bytes); | |
| 727 | ✗ | response->status->status = | |
| 728 | static_cast<std::int32_t>(petps::RpcStatus::kOk); | ||
| 729 | ✗ | response->status->response_bytes = | |
| 730 | ✗ | static_cast<std::uint32_t>(descriptor.response_bytes); | |
| 731 | ✗ | return false; | |
| 732 | } | ||
| 733 | ✗ | if (FLAGS_rdma_rc_fake_get_mode == "index_only") { | |
| 734 | base::ConstArray<std::uint64_t> keys( | ||
| 735 | reinterpret_cast<const std::uint64_t*>(payload), | ||
| 736 | ✗ | descriptor.key_count); | |
| 737 | ✗ | CachePS::FlatGetProfile get_profile; | |
| 738 | ✗ | CachePS::FlatGetProfile* get_profile_ptr = | |
| 739 | ✗ | FLAGS_rdma_rc_profile_interval_ms > 0 ? &get_profile : nullptr; | |
| 740 | const bool ok = | ||
| 741 | ✗ | cache_ps_->ProbeParameterIndex(keys, thread_id, get_profile_ptr); | |
| 742 | ✗ | if (get_profile_ptr != nullptr) { | |
| 743 | ✗ | AccumulateFlatGetProfile(get_profile); | |
| 744 | } | ||
| 745 | ✗ | response->status->status = static_cast<std::int32_t>( | |
| 746 | ok ? petps::RpcStatus::kOk : petps::RpcStatus::kValueSizeMismatch); | ||
| 747 | ✗ | response->status->response_bytes = 0; | |
| 748 | ✗ | return false; | |
| 749 | } | ||
| 750 | ✗ | if (FLAGS_rdma_rc_fake_get_mode != "none" && | |
| 751 | ✗ | !FLAGS_rdma_rc_fake_get_mode.empty()) { | |
| 752 | ✗ | response->status->status = | |
| 753 | static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload); | ||
| 754 | ✗ | response->status->response_bytes = 0; | |
| 755 | ✗ | return false; | |
| 756 | } | ||
| 757 | |||
| 758 | base::ConstArray<std::uint64_t> keys( | ||
| 759 | ✗ | reinterpret_cast<const std::uint64_t*>(payload), descriptor.key_count); | |
| 760 | ✗ | CachePS::FlatGetProfile get_profile; | |
| 761 | ✗ | CachePS::FlatGetProfile* get_profile_ptr = | |
| 762 | ✗ | FLAGS_rdma_rc_profile_interval_ms > 0 ? &get_profile : nullptr; | |
| 763 | ✗ | if ((descriptor.flags & petps::kRcFlagGetDirectSg) != 0) { | |
| 764 | ✗ | const bool direct_ok = HandleGetDirectSg( | |
| 765 | descriptor, keys, response, thread_id, slot_in_qp, get_profile_ptr); | ||
| 766 | ✗ | if (direct_ok) { | |
| 767 | ✗ | return true; | |
| 768 | } | ||
| 769 | ✗ | if (FLAGS_rdma_rc_profile_interval_ms > 0) { | |
| 770 | ✗ | profile_.get_direct_sg_fallback.fetch_add(1, std::memory_order_relaxed); | |
| 771 | } | ||
| 772 | ✗ | if ((descriptor.flags & petps::kRcFlagGetAllowFallbackCopy) == 0) { | |
| 773 | ✗ | response->status->status = | |
| 774 | static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload); | ||
| 775 | ✗ | response->status->response_bytes = 0; | |
| 776 | ✗ | return false; | |
| 777 | } | ||
| 778 | } | ||
| 779 | ✗ | const bool ok = cache_ps_->GetParameterFlat( | |
| 780 | keys, | ||
| 781 | ✗ | reinterpret_cast<float*>(response->payload), | |
| 782 | ✗ | descriptor.key_count, | |
| 783 | ✗ | descriptor.embedding_dim, | |
| 784 | thread_id, | ||
| 785 | get_profile_ptr); | ||
| 786 | ✗ | if (get_profile_ptr != nullptr) { | |
| 787 | ✗ | AccumulateFlatGetProfile(get_profile); | |
| 788 | } | ||
| 789 | ✗ | response->status->status = static_cast<std::int32_t>( | |
| 790 | ok ? petps::RpcStatus::kOk : petps::RpcStatus::kValueSizeMismatch); | ||
| 791 | ✗ | response->status->response_bytes = | |
| 792 | ✗ | static_cast<std::uint32_t>(descriptor.response_bytes); | |
| 793 | ✗ | return false; | |
| 794 | } | ||
| 795 | |||
| 796 | ✗ | void HandlePut(const petps::RequestDescriptor& descriptor, | |
| 797 | const char* payload, | ||
| 798 | petps::RcShardServerTransport::ResponseView* response, | ||
| 799 | int thread_id) { | ||
| 800 | ✗ | const auto* reader = | |
| 801 | reinterpret_cast<const ParameterCompressReader*>(payload); | ||
| 802 | ✗ | if (!reader->Valid(static_cast<int>(descriptor.payload_bytes))) { | |
| 803 | ✗ | response->status->status = | |
| 804 | static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload); | ||
| 805 | ✗ | response->status->response_bytes = 0; | |
| 806 | ✗ | return; | |
| 807 | } | ||
| 808 | ✗ | for (int i = 0; i < reader->item_size(); ++i) { | |
| 809 | ✗ | cache_ps_->PutSingleParameter(reader->item(i), thread_id); | |
| 810 | } | ||
| 811 | ✗ | response->status->status = static_cast<std::int32_t>(petps::RpcStatus::kOk); | |
| 812 | ✗ | response->status->response_bytes = 0; | |
| 813 | } | ||
| 814 | |||
| 815 | ✗ | void HandleUpdate(const petps::RequestDescriptor& descriptor, | |
| 816 | const char* payload, | ||
| 817 | petps::RcShardServerTransport::ResponseView* response, | ||
| 818 | int thread_id) { | ||
| 819 | ✗ | const std::string_view table_name = petps::DescriptorTableName(descriptor); | |
| 820 | ✗ | if (table_name.empty()) { | |
| 821 | ✗ | response->status->status = | |
| 822 | static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload); | ||
| 823 | ✗ | response->status->response_bytes = 0; | |
| 824 | ✗ | return; | |
| 825 | } | ||
| 826 | |||
| 827 | ✗ | const auto* reader = | |
| 828 | reinterpret_cast<const ParameterCompressReader*>(payload); | ||
| 829 | ✗ | if (!reader->Valid(static_cast<int>(descriptor.payload_bytes))) { | |
| 830 | ✗ | response->status->status = | |
| 831 | static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload); | ||
| 832 | ✗ | response->status->response_bytes = 0; | |
| 833 | ✗ | return; | |
| 834 | } | ||
| 835 | |||
| 836 | ✗ | const bool ok = cache_ps_->UpdateParameter( | |
| 837 | ✗ | std::string(table_name), reader, static_cast<unsigned>(thread_id)); | |
| 838 | ✗ | response->status->status = static_cast<std::int32_t>( | |
| 839 | ok ? petps::RpcStatus::kOk : petps::RpcStatus::kInvalidPayload); | ||
| 840 | ✗ | response->status->response_bytes = 0; | |
| 841 | } | ||
| 842 | |||
| 843 | ✗ | void HandleInitTable(const petps::RequestDescriptor& descriptor, | |
| 844 | const char* payload, | ||
| 845 | petps::RcShardServerTransport::ResponseView* response) { | ||
| 846 | ✗ | const std::string_view table_name = petps::DescriptorTableName(descriptor); | |
| 847 | ✗ | if (table_name.empty() || | |
| 848 | ✗ | descriptor.payload_bytes != petps::InitTablePayloadBytes()) { | |
| 849 | ✗ | response->status->status = | |
| 850 | static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload); | ||
| 851 | ✗ | response->status->response_bytes = 0; | |
| 852 | ✗ | return; | |
| 853 | } | ||
| 854 | |||
| 855 | ✗ | std::uint64_t num_embeddings = 0; | |
| 856 | ✗ | std::uint64_t embedding_dim = 0; | |
| 857 | ✗ | std::memcpy(&num_embeddings, payload, sizeof(num_embeddings)); | |
| 858 | ✗ | std::memcpy(&embedding_dim, | |
| 859 | ✗ | payload + sizeof(num_embeddings), | |
| 860 | sizeof(embedding_dim)); | ||
| 861 | ✗ | const bool ok = cache_ps_->InitTable( | |
| 862 | ✗ | std::string(table_name), num_embeddings, embedding_dim); | |
| 863 | ✗ | response->status->status = static_cast<std::int32_t>( | |
| 864 | ok ? petps::RpcStatus::kOk : petps::RpcStatus::kInvalidPayload); | ||
| 865 | ✗ | response->status->response_bytes = 0; | |
| 866 | } | ||
| 867 | |||
| 868 | ✗ | void MaybePublishServerReady() { | |
| 869 | const int started = | ||
| 870 | ✗ | started_threads_.fetch_add(1, std::memory_order_relaxed) + 1; | |
| 871 | ✗ | if (started != thread_count_ || | |
| 872 | ✗ | ready_published_.exchange(true, std::memory_order_acq_rel)) { | |
| 873 | ✗ | return; | |
| 874 | } | ||
| 875 | ✗ | control_plane_client_.PublishServerReady(FLAGS_global_id); | |
| 876 | ✗ | LOG(INFO) << "component=rdma_control_plane event=server_ready_published" | |
| 877 | ✗ | << " server_id=" << FLAGS_global_id | |
| 878 | << " host=" << FLAGS_rdma_control_plane_host | ||
| 879 | ✗ | << " port=" << FLAGS_rdma_control_plane_port; | |
| 880 | } | ||
| 881 | |||
| 882 | ✗ | void PollingThread(int thread_id) { | |
| 883 | ✗ | BindServerCore(thread_id); | |
| 884 | ✗ | LOG(INFO) << "component=rdma_server event=polling_thread_ready thread_id=" | |
| 885 | ✗ | << thread_id; | |
| 886 | ✗ | MaybePublishServerReady(); | |
| 887 | const int coroutines_per_thread = | ||
| 888 | ✗ | std::max(1, FLAGS_rdma_rc_server_coroutines_per_thread); | |
| 889 | ✗ | LOG(INFO) << "component=rdma_rc_server event=polling_thread_mode" | |
| 890 | ✗ | << " thread_id=" << thread_id | |
| 891 | ✗ | << " coroutines_per_thread=" << coroutines_per_thread; | |
| 892 | ✗ | if (coroutines_per_thread > 1) { | |
| 893 | ✗ | RunCoroutinePollingThread(thread_id, coroutines_per_thread); | |
| 894 | ✗ | return; | |
| 895 | } | ||
| 896 | while (true) { | ||
| 897 | ✗ | const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0; | |
| 898 | ✗ | const std::uint64_t poll_start_ns = profile_enabled ? NowNs() : 0; | |
| 899 | ✗ | std::uint64_t scanned_slots = 0; | |
| 900 | ✗ | std::uint64_t ready_slots = 0; | |
| 901 | ✗ | DrainGetPayloadCompletions(thread_id, profile_enabled); | |
| 902 | ✗ | ScanAssignedSlots( | |
| 903 | thread_id, | ||
| 904 | /*worker_id=*/0, | ||
| 905 | /*worker_count=*/1, | ||
| 906 | profile_enabled, | ||
| 907 | &scanned_slots, | ||
| 908 | &ready_slots); | ||
| 909 | ✗ | DrainGetPayloadCompletions(thread_id, profile_enabled); | |
| 910 | ✗ | if (profile_enabled) { | |
| 911 | ✗ | profile_.scan_rounds.fetch_add(1, std::memory_order_relaxed); | |
| 912 | ✗ | profile_.scanned_slots.fetch_add( | |
| 913 | scanned_slots, std::memory_order_relaxed); | ||
| 914 | ✗ | if (ready_slots == 0) { | |
| 915 | ✗ | profile_.empty_scan_rounds.fetch_add(1, std::memory_order_relaxed); | |
| 916 | } | ||
| 917 | ✗ | UpdateMax(&profile_.max_ready_per_round, ready_slots); | |
| 918 | ✗ | const std::uint64_t poll_loop_ns = NowNs() - poll_start_ns; | |
| 919 | ✗ | profile_.poll_loop_ns.fetch_add( | |
| 920 | poll_loop_ns, std::memory_order_relaxed); | ||
| 921 | auto& poller = | ||
| 922 | ✗ | *poller_profiles_.at(static_cast<std::size_t>(thread_id)); | |
| 923 | ✗ | poller.scan_rounds.fetch_add(1, std::memory_order_relaxed); | |
| 924 | ✗ | poller.scanned_slots.fetch_add( | |
| 925 | scanned_slots, std::memory_order_relaxed); | ||
| 926 | ✗ | poller.ready_slots.fetch_add(ready_slots, std::memory_order_relaxed); | |
| 927 | ✗ | poller.poll_loop_ns.fetch_add(poll_loop_ns, std::memory_order_relaxed); | |
| 928 | ✗ | MaybeReportProfile(thread_id); | |
| 929 | } | ||
| 930 | ✗ | std::this_thread::yield(); | |
| 931 | ✗ | } | |
| 932 | } | ||
| 933 | |||
| 934 | ✗ | bool ProcessSlot(int slot, int thread_id, bool profile_enabled) { | |
| 935 | ✗ | int client_id = -1; | |
| 936 | ✗ | int qp_index = -1; | |
| 937 | ✗ | int slot_in_qp = -1; | |
| 938 | ✗ | transport_->DecodeSlotIndex(slot, &client_id, &qp_index, &slot_in_qp); | |
| 939 | ✗ | auto* commit = transport_->RequestCommitAt(slot); | |
| 940 | ✗ | if (commit->state.load(std::memory_order_acquire) != petps::kRcSlotReady) { | |
| 941 | ✗ | if (profile_enabled) { | |
| 942 | ✗ | profile_.not_ready_slots.fetch_add(1, std::memory_order_relaxed); | |
| 943 | ✗ | poller_profiles_.at(static_cast<std::size_t>(thread_id)) | |
| 944 | ✗ | ->not_ready_slots.fetch_add(1, std::memory_order_relaxed); | |
| 945 | } | ||
| 946 | ✗ | return false; | |
| 947 | } | ||
| 948 | ✗ | const std::uint64_t seq = commit->seq.load(std::memory_order_acquire); | |
| 949 | ✗ | if (seq == 0) { | |
| 950 | ✗ | if (profile_enabled) { | |
| 951 | ✗ | profile_.zero_seq_ready.fetch_add(1, std::memory_order_relaxed); | |
| 952 | } | ||
| 953 | ✗ | return false; | |
| 954 | } | ||
| 955 | ✗ | if (seq == last_seq_[static_cast<std::size_t>(slot)]) { | |
| 956 | ✗ | if (profile_enabled) { | |
| 957 | ✗ | profile_.duplicate_seq_ready.fetch_add(1, std::memory_order_relaxed); | |
| 958 | ✗ | poller_profiles_.at(static_cast<std::size_t>(thread_id)) | |
| 959 | ✗ | ->duplicate_seq_ready.fetch_add(1, std::memory_order_relaxed); | |
| 960 | } | ||
| 961 | ✗ | return false; | |
| 962 | } | ||
| 963 | ✗ | if (GetPayloadOffloadEnabled() && | |
| 964 | ✗ | seq == inflight_seq_[static_cast<std::size_t>(slot)]) { | |
| 965 | ✗ | if (profile_enabled) { | |
| 966 | ✗ | profile_.inflight_seq_ready.fetch_add(1, std::memory_order_relaxed); | |
| 967 | ✗ | poller_profiles_.at(static_cast<std::size_t>(thread_id)) | |
| 968 | ✗ | ->inflight_seq_ready.fetch_add(1, std::memory_order_relaxed); | |
| 969 | } | ||
| 970 | ✗ | return false; | |
| 971 | } | ||
| 972 | ✗ | if (profile_enabled) { | |
| 973 | ✗ | profile_.ready_slots.fetch_add(1, std::memory_order_relaxed); | |
| 974 | } | ||
| 975 | |||
| 976 | ✗ | auto* descriptor = transport_->RequestDescriptorAt(slot); | |
| 977 | ✗ | std::string error; | |
| 978 | ✗ | if (!petps::ValidateRequestDescriptor( | |
| 979 | *descriptor, | ||
| 980 | ✗ | transport_->config().request_slot_bytes, | |
| 981 | ✗ | transport_->config().response_slot_bytes, | |
| 982 | &error)) { | ||
| 983 | ✗ | LOG(ERROR) << "component=rdma_rc_server event=invalid_descriptor" | |
| 984 | ✗ | << " shard=" << shard_id_ << " slot=" << slot | |
| 985 | ✗ | << " thread_id=" << thread_id << " seq=" << seq | |
| 986 | ✗ | << " descriptor_seq=" << descriptor->seq | |
| 987 | ✗ | << " client_id=" << descriptor->client_id | |
| 988 | ✗ | << " qp=" << descriptor->qp_index << " op=" << descriptor->op | |
| 989 | ✗ | << " key_count=" << descriptor->key_count | |
| 990 | ✗ | << " payload_bytes=" << descriptor->payload_bytes | |
| 991 | ✗ | << " response_bytes=" << descriptor->response_bytes | |
| 992 | ✗ | << " error=\"" << error << "\""; | |
| 993 | ✗ | if (profile_enabled) { | |
| 994 | ✗ | profile_.invalid_descriptor.fetch_add(1, std::memory_order_relaxed); | |
| 995 | } | ||
| 996 | ✗ | last_seq_[static_cast<std::size_t>(slot)] = seq; | |
| 997 | ✗ | commit->state.store(0, std::memory_order_release); | |
| 998 | ✗ | return true; | |
| 999 | } | ||
| 1000 | ✗ | if (descriptor->client_id != static_cast<std::uint32_t>(client_id) || | |
| 1001 | ✗ | descriptor->qp_index != static_cast<std::uint32_t>(qp_index)) { | |
| 1002 | ✗ | LOG(ERROR) << "component=rdma_rc_server event=slot_descriptor_mismatch" | |
| 1003 | ✗ | << " shard=" << shard_id_ << " slot=" << slot | |
| 1004 | ✗ | << " thread_id=" << thread_id | |
| 1005 | ✗ | << " slot_client_id=" << client_id << " slot_qp=" << qp_index | |
| 1006 | ✗ | << " descriptor_client_id=" << descriptor->client_id | |
| 1007 | ✗ | << " descriptor_qp=" << descriptor->qp_index << " seq=" << seq; | |
| 1008 | ✗ | if (profile_enabled) { | |
| 1009 | ✗ | profile_.invalid_descriptor.fetch_add(1, std::memory_order_relaxed); | |
| 1010 | } | ||
| 1011 | ✗ | last_seq_[static_cast<std::size_t>(slot)] = seq; | |
| 1012 | ✗ | commit->state.store(0, std::memory_order_release); | |
| 1013 | ✗ | return true; | |
| 1014 | } | ||
| 1015 | |||
| 1016 | auto response = | ||
| 1017 | ✗ | transport_->OpenClientResponse(client_id, qp_index, slot_in_qp); | |
| 1018 | ✗ | const char* payload = transport_->RequestPayloadAt(slot); | |
| 1019 | ✗ | VLOG(1) << "component=rdma_rc_server event=consume shard=" << shard_id_ | |
| 1020 | ✗ | << " slot=" << slot << " client_id=" << descriptor->client_id | |
| 1021 | ✗ | << " qp=" << descriptor->qp_index << " seq=" << seq << " op=" | |
| 1022 | ✗ | << descriptor->op << " key_count=" << descriptor->key_count | |
| 1023 | ✗ | << " payload_bytes=" << descriptor->payload_bytes | |
| 1024 | ✗ | << " response_bytes=" << descriptor->response_bytes; | |
| 1025 | ✗ | response.status->status = | |
| 1026 | static_cast<std::int32_t>(petps::RpcStatus::kInvalidPayload); | ||
| 1027 | ✗ | response.status->response_bytes = 0; | |
| 1028 | |||
| 1029 | ✗ | if (descriptor->shard_id != static_cast<std::uint32_t>(shard_id_)) { | |
| 1030 | ✗ | LOG(ERROR) << "component=rdma_rc_server event=wrong_shard" | |
| 1031 | ✗ | << " expected_shard=" << shard_id_ | |
| 1032 | ✗ | << " actual_shard=" << descriptor->shard_id << " slot=" << slot | |
| 1033 | ✗ | << " client_id=" << descriptor->client_id | |
| 1034 | ✗ | << " qp=" << descriptor->qp_index << " seq=" << seq << " op=" | |
| 1035 | ✗ | << descriptor->op << " key_count=" << descriptor->key_count; | |
| 1036 | ✗ | if (profile_enabled) { | |
| 1037 | ✗ | profile_.wrong_shard.fetch_add(1, std::memory_order_relaxed); | |
| 1038 | } | ||
| 1039 | ✗ | response.status->status = | |
| 1040 | static_cast<std::int32_t>(petps::RpcStatus::kWrongShard); | ||
| 1041 | ✗ | } else if (descriptor->op == | |
| 1042 | static_cast<std::uint16_t>(petps::RcOp::kGet)) { | ||
| 1043 | ✗ | if (GetPayloadOffloadEnabled()) { | |
| 1044 | const GetPayloadTask task{ | ||
| 1045 | slot, | ||
| 1046 | client_id, | ||
| 1047 | qp_index, | ||
| 1048 | slot_in_qp, | ||
| 1049 | thread_id, | ||
| 1050 | seq, | ||
| 1051 | *descriptor, | ||
| 1052 | payload, | ||
| 1053 | response, | ||
| 1054 | ✗ | }; | |
| 1055 | ✗ | if (!EnqueueGetPayloadTask(task)) { | |
| 1056 | ✗ | return false; | |
| 1057 | } | ||
| 1058 | ✗ | inflight_seq_[static_cast<std::size_t>(slot)] = seq; | |
| 1059 | ✗ | return true; | |
| 1060 | } else { | ||
| 1061 | ✗ | const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0; | |
| 1062 | const bool payload_written_direct = | ||
| 1063 | ✗ | HandleGet(*descriptor, payload, &response, thread_id, slot_in_qp); | |
| 1064 | ✗ | if (profile_enabled) { | |
| 1065 | ✗ | profile_.handled_get.fetch_add(1, std::memory_order_relaxed); | |
| 1066 | ✗ | profile_.handle_get_ns.fetch_add( | |
| 1067 | ✗ | NowNs() - handle_start_ns, std::memory_order_relaxed); | |
| 1068 | ✗ | poller_profiles_.at(static_cast<std::size_t>(thread_id)) | |
| 1069 | ✗ | ->handled_get.fetch_add(1, std::memory_order_relaxed); | |
| 1070 | } | ||
| 1071 | ✗ | if (payload_written_direct) { | |
| 1072 | ✗ | CompleteResponseStatusOnlyForSlot( | |
| 1073 | slot, | ||
| 1074 | client_id, | ||
| 1075 | qp_index, | ||
| 1076 | slot_in_qp, | ||
| 1077 | response, | ||
| 1078 | seq, | ||
| 1079 | profile_enabled); | ||
| 1080 | ✗ | return true; | |
| 1081 | } | ||
| 1082 | } | ||
| 1083 | ✗ | } else if (descriptor->op == | |
| 1084 | static_cast<std::uint16_t>(petps::RcOp::kPut)) { | ||
| 1085 | ✗ | const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0; | |
| 1086 | ✗ | HandlePut(*descriptor, payload, &response, thread_id); | |
| 1087 | ✗ | if (profile_enabled) { | |
| 1088 | ✗ | profile_.handled_put.fetch_add(1, std::memory_order_relaxed); | |
| 1089 | ✗ | profile_.handle_put_ns.fetch_add( | |
| 1090 | ✗ | NowNs() - handle_start_ns, std::memory_order_relaxed); | |
| 1091 | } | ||
| 1092 | ✗ | } else if (descriptor->op == | |
| 1093 | static_cast<std::uint16_t>(petps::RcOp::kUpdate)) { | ||
| 1094 | ✗ | const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0; | |
| 1095 | ✗ | HandleUpdate(*descriptor, payload, &response, thread_id); | |
| 1096 | ✗ | if (profile_enabled) { | |
| 1097 | ✗ | profile_.handled_update.fetch_add(1, std::memory_order_relaxed); | |
| 1098 | ✗ | profile_.handle_update_ns.fetch_add( | |
| 1099 | ✗ | NowNs() - handle_start_ns, std::memory_order_relaxed); | |
| 1100 | } | ||
| 1101 | ✗ | } else if (descriptor->op == | |
| 1102 | static_cast<std::uint16_t>(petps::RcOp::kInitTable)) { | ||
| 1103 | ✗ | const std::uint64_t handle_start_ns = profile_enabled ? NowNs() : 0; | |
| 1104 | ✗ | HandleInitTable(*descriptor, payload, &response); | |
| 1105 | ✗ | if (profile_enabled) { | |
| 1106 | ✗ | profile_.handled_init.fetch_add(1, std::memory_order_relaxed); | |
| 1107 | ✗ | profile_.handle_init_ns.fetch_add( | |
| 1108 | ✗ | NowNs() - handle_start_ns, std::memory_order_relaxed); | |
| 1109 | } | ||
| 1110 | } | ||
| 1111 | |||
| 1112 | ✗ | CompleteResponseForSlot( | |
| 1113 | slot, client_id, qp_index, slot_in_qp, response, seq, profile_enabled); | ||
| 1114 | ✗ | return true; | |
| 1115 | ✗ | } | |
| 1116 | |||
| 1117 | ✗ | void ScanAssignedSlots( | |
| 1118 | int thread_id, | ||
| 1119 | int worker_id, | ||
| 1120 | int worker_count, | ||
| 1121 | bool profile_enabled, | ||
| 1122 | std::uint64_t* scanned_slots, | ||
| 1123 | std::uint64_t* ready_slots) { | ||
| 1124 | ✗ | const int qp_count = transport_->config().qps_per_client_per_shard; | |
| 1125 | ✗ | const int slots_per_qp = transport_->config().slots_per_qp; | |
| 1126 | ✗ | const int num_clients = transport_->config().num_clients; | |
| 1127 | ✗ | const int lane_slots = num_clients * slots_per_qp; | |
| 1128 | ✗ | for (int qp_index = thread_id; qp_index < qp_count; | |
| 1129 | ✗ | qp_index += thread_count_) { | |
| 1130 | ✗ | for (int lane_slot = worker_id; lane_slot < lane_slots; | |
| 1131 | ✗ | lane_slot += worker_count) { | |
| 1132 | ✗ | const int client_id = lane_slot / slots_per_qp; | |
| 1133 | ✗ | const int slot_in_qp = lane_slot % slots_per_qp; | |
| 1134 | const int slot_index = | ||
| 1135 | ✗ | transport_->SlotIndex(client_id, qp_index, slot_in_qp); | |
| 1136 | ✗ | ++(*scanned_slots); | |
| 1137 | ✗ | if (ProcessSlot(slot_index, thread_id, profile_enabled)) { | |
| 1138 | ✗ | ++(*ready_slots); | |
| 1139 | } | ||
| 1140 | } | ||
| 1141 | } | ||
| 1142 | ✗ | } | |
| 1143 | |||
| 1144 | ✗ | void CoroutineSlotScanner( | |
| 1145 | boost::coroutines2::coroutine<void>::push_type& sink, | ||
| 1146 | int thread_id, | ||
| 1147 | int worker_id, | ||
| 1148 | int worker_count) { | ||
| 1149 | while (true) { | ||
| 1150 | ✗ | const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0; | |
| 1151 | ✗ | const std::uint64_t poll_start_ns = profile_enabled ? NowNs() : 0; | |
| 1152 | ✗ | std::uint64_t scanned_slots = 0; | |
| 1153 | ✗ | std::uint64_t ready_slots = 0; | |
| 1154 | ✗ | DrainGetPayloadCompletions(thread_id, profile_enabled); | |
| 1155 | ✗ | ScanAssignedSlots( | |
| 1156 | thread_id, | ||
| 1157 | worker_id, | ||
| 1158 | worker_count, | ||
| 1159 | profile_enabled, | ||
| 1160 | &scanned_slots, | ||
| 1161 | &ready_slots); | ||
| 1162 | ✗ | DrainGetPayloadCompletions(thread_id, profile_enabled); | |
| 1163 | ✗ | if (profile_enabled) { | |
| 1164 | ✗ | profile_.scan_rounds.fetch_add(1, std::memory_order_relaxed); | |
| 1165 | ✗ | profile_.scanned_slots.fetch_add( | |
| 1166 | scanned_slots, std::memory_order_relaxed); | ||
| 1167 | ✗ | if (ready_slots == 0) { | |
| 1168 | ✗ | profile_.empty_scan_rounds.fetch_add(1, std::memory_order_relaxed); | |
| 1169 | } | ||
| 1170 | ✗ | UpdateMax(&profile_.max_ready_per_round, ready_slots); | |
| 1171 | ✗ | const std::uint64_t poll_loop_ns = NowNs() - poll_start_ns; | |
| 1172 | ✗ | profile_.poll_loop_ns.fetch_add( | |
| 1173 | poll_loop_ns, std::memory_order_relaxed); | ||
| 1174 | auto& poller = | ||
| 1175 | ✗ | *poller_profiles_.at(static_cast<std::size_t>(thread_id)); | |
| 1176 | ✗ | poller.scan_rounds.fetch_add(1, std::memory_order_relaxed); | |
| 1177 | ✗ | poller.scanned_slots.fetch_add( | |
| 1178 | scanned_slots, std::memory_order_relaxed); | ||
| 1179 | ✗ | poller.ready_slots.fetch_add(ready_slots, std::memory_order_relaxed); | |
| 1180 | ✗ | poller.poll_loop_ns.fetch_add(poll_loop_ns, std::memory_order_relaxed); | |
| 1181 | } | ||
| 1182 | ✗ | sink(); | |
| 1183 | ✗ | } | |
| 1184 | } | ||
| 1185 | |||
| 1186 | ✗ | void RunCoroutinePollingThread(int thread_id, int coroutines_per_thread) { | |
| 1187 | using Coroutine = boost::coroutines2::coroutine<void>; | ||
| 1188 | ✗ | std::vector<std::unique_ptr<Coroutine::pull_type>> coroutines; | |
| 1189 | ✗ | coroutines.reserve(static_cast<std::size_t>(coroutines_per_thread)); | |
| 1190 | ✗ | for (int coroutine_id = 0; coroutine_id < coroutines_per_thread; | |
| 1191 | ++coroutine_id) { | ||
| 1192 | ✗ | coroutines.emplace_back(std::make_unique<Coroutine::pull_type>( | |
| 1193 | ✗ | [this, thread_id, coroutine_id, coroutines_per_thread]( | |
| 1194 | ✗ | Coroutine::push_type& sink) { | |
| 1195 | ✗ | CoroutineSlotScanner( | |
| 1196 | sink, thread_id, coroutine_id, coroutines_per_thread); | ||
| 1197 | ✗ | })); | |
| 1198 | } | ||
| 1199 | while (true) { | ||
| 1200 | ✗ | for (auto& coroutine : coroutines) { | |
| 1201 | ✗ | (*coroutine)(); | |
| 1202 | } | ||
| 1203 | ✗ | MaybeReportProfile(thread_id); | |
| 1204 | ✗ | std::this_thread::yield(); | |
| 1205 | ✗ | } | |
| 1206 | ✗ | } | |
| 1207 | |||
| 1208 | CachePS* cache_ps_ = nullptr; | ||
| 1209 | int thread_count_ = 1; | ||
| 1210 | int shard_id_ = 0; | ||
| 1211 | std::unique_ptr<petps::RcShardServerTransport> transport_; | ||
| 1212 | petps::RdmaControlPlaneClient control_plane_client_; | ||
| 1213 | std::vector<std::thread> threads_; | ||
| 1214 | std::vector<std::uint64_t> last_seq_; | ||
| 1215 | std::vector<std::uint64_t> inflight_seq_; | ||
| 1216 | std::vector<std::unique_ptr<PollerProfile>> poller_profiles_; | ||
| 1217 | int get_payload_worker_count_ = 0; | ||
| 1218 | std::vector<std::thread> get_payload_workers_; | ||
| 1219 | std::mutex get_payload_mu_; | ||
| 1220 | std::condition_variable get_payload_cv_; | ||
| 1221 | std::deque<GetPayloadTask> get_payload_tasks_; | ||
| 1222 | std::vector<std::deque<GetPayloadCompletion>> get_payload_completions_; | ||
| 1223 | std::atomic<int> started_threads_{0}; | ||
| 1224 | std::atomic<bool> ready_published_{false}; | ||
| 1225 | ProfileCounters profile_; | ||
| 1226 | }; | ||
| 1227 | |||
| 1228 | } // namespace | ||
| 1229 | |||
| 1230 | ✗ | int main(int argc, char* argv[]) { | |
| 1231 | ✗ | folly::init(&argc, &argv); | |
| 1232 | ✗ | if (ShouldTraceRdmaGet()) { | |
| 1233 | ✗ | std::cerr << "component=rdma_get_trace side=server event=enabled interval=" | |
| 1234 | ✗ | << RdmaGetTraceInterval() << std::endl; | |
| 1235 | } | ||
| 1236 | ✗ | xmh::Reporter::StartReportThread(); | |
| 1237 | |||
| 1238 | ✗ | base::PMMmapRegisterCenter::GetConfig().backend = | |
| 1239 | ✗ | base::PMMmapRegisterCenter::BackendFromUseDram(FLAGS_use_dram); | |
| 1240 | ✗ | base::PMMmapRegisterCenter::GetConfig().numa_id = FLAGS_numa_id; | |
| 1241 | |||
| 1242 | ✗ | base::global_socket_id = FLAGS_numa_id; | |
| 1243 | ✗ | LOG(INFO) << "set NUMA ID = " << FLAGS_numa_id; | |
| 1244 | |||
| 1245 | const std::string config_path = | ||
| 1246 | ✗ | FLAGS_config_path.empty() | |
| 1247 | ✗ | ? base::ResolveRecStoreConfigPath().string() | |
| 1248 | ✗ | : FLAGS_config_path; | |
| 1249 | ✗ | std::ifstream config_file(config_path); | |
| 1250 | ✗ | if (!config_file.is_open()) { | |
| 1251 | ✗ | LOG(FATAL) << "Cannot open config file: " << config_path; | |
| 1252 | } | ||
| 1253 | |||
| 1254 | ✗ | nlohmann::json config; | |
| 1255 | ✗ | config_file >> config; | |
| 1256 | ✗ | if (config.contains("cache_ps") && config["cache_ps"].is_object() && | |
| 1257 | ✗ | config["cache_ps"].contains("base_kv_config")) { | |
| 1258 | ✗ | NormalizeDramValuePath(&config["cache_ps"]["base_kv_config"]); | |
| 1259 | } | ||
| 1260 | ✗ | if (config.contains("distributed_client") && | |
| 1261 | ✗ | config["distributed_client"].is_object() && | |
| 1262 | ✗ | config["distributed_client"].contains("base_kv_config")) { | |
| 1263 | ✗ | NormalizeDramValuePath(&config["distributed_client"]["base_kv_config"]); | |
| 1264 | } | ||
| 1265 | ✗ | std::unique_ptr<petps::RdmaControlPlaneServer> control_plane_server; | |
| 1266 | ✗ | if (FLAGS_global_id == 0) { | |
| 1267 | ✗ | control_plane_server = std::make_unique<petps::RdmaControlPlaneServer>( | |
| 1268 | ✗ | petps::RdmaControlPlaneEndpoint{ | |
| 1269 | FLAGS_rdma_control_plane_host, | ||
| 1270 | FLAGS_rdma_control_plane_port, | ||
| 1271 | FLAGS_rdma_control_plane_timeout_ms, | ||
| 1272 | ✗ | }); | |
| 1273 | ✗ | control_plane_server->Start(); | |
| 1274 | ✗ | LOG(INFO) << "component=rdma_control_plane event=listening" | |
| 1275 | << " server_id=0" | ||
| 1276 | << " host=" << FLAGS_rdma_control_plane_host | ||
| 1277 | ✗ | << " port=" << FLAGS_rdma_control_plane_port; | |
| 1278 | } | ||
| 1279 | ✗ | auto cache_ps = std::make_unique<CachePS>(config["cache_ps"]); | |
| 1280 | ✗ | const int shard_id = ResolveShardId(config); | |
| 1281 | auto ps = std::make_unique<PetPSServer>( | ||
| 1282 | ✗ | cache_ps.get(), FLAGS_thread_num, shard_id, NamespaceToken()); | |
| 1283 | ✗ | ps->Run(); | |
| 1284 | while (true) { | ||
| 1285 | ✗ | std::this_thread::sleep_for(std::chrono::seconds(1)); | |
| 1286 | } | ||
| 1287 | return 0; | ||
| 1288 | ✗ | } | |
| 1289 |