ps/rdma/petps_client.cc
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "ps/rdma/petps_client.h" | ||
| 2 | |||
| 3 | #include <algorithm> | ||
| 4 | #include <cstdlib> | ||
| 5 | #include <cstring> | ||
| 6 | #include <iostream> | ||
| 7 | #include <stdexcept> | ||
| 8 | #include <thread> | ||
| 9 | |||
| 10 | #include <folly/portability/GFlags.h> | ||
| 11 | |||
| 12 | #include "ps/rdma/control_plane.h" | ||
| 13 | #include "ps/rdma/rdma_common.h" | ||
| 14 | #include "ps/rdma/rc_options.h" | ||
| 15 | |||
| 16 | DECLARE_int32(global_id); | ||
| 17 | DECLARE_int32(num_server_processes); | ||
| 18 | DECLARE_int32(num_client_processes); | ||
| 19 | DECLARE_int32(value_size); | ||
| 20 | DECLARE_int32(max_kv_num_per_request); | ||
| 21 | DEFINE_string(rdma_get_response_mode, | ||
| 22 | "direct_sg", | ||
| 23 | "RDMA GET response mode: direct_sg or staging_copy"); | ||
| 24 | |||
| 25 | namespace petps { | ||
| 26 | namespace { | ||
| 27 | |||
| 28 | using petps::Exchange; | ||
| 29 | using petps::NamespaceToken; | ||
| 30 | using petps::NowNs; | ||
| 31 | |||
| 32 | ✗ | std::size_t ComputeMaxGetKeysPerRpc() { | |
| 33 | ✗ | return GetKeysPerRpcByResponseBudget( | |
| 34 | static_cast<std::size_t>(FLAGS_value_size), | ||
| 35 | static_cast<std::size_t>(FLAGS_rdma_rc_mtu_bytes), | ||
| 36 | ✗ | static_cast<std::size_t>(FLAGS_rdma_rc_target_response_mtu)); | |
| 37 | } | ||
| 38 | |||
| 39 | ✗ | std::int32_t WaitStatus(const StatusWord* status, std::uint64_t seq) { | |
| 40 | ✗ | const auto start = std::chrono::steady_clock::now(); | |
| 41 | ✗ | int spin_iterations = 0; | |
| 42 | ✗ | while (!StatusWordDone(*status, seq)) { | |
| 43 | ✗ | if (spin_iterations < FLAGS_rdma_rc_wait_spin_iterations) { | |
| 44 | ✗ | ++spin_iterations; | |
| 45 | } else { | ||
| 46 | ✗ | spin_iterations = 0; | |
| 47 | ✗ | std::this_thread::yield(); | |
| 48 | } | ||
| 49 | ✗ | if (FLAGS_rdma_wait_timeout_ms > 0) { | |
| 50 | const auto elapsed_ms = | ||
| 51 | ✗ | std::chrono::duration_cast<std::chrono::milliseconds>( | |
| 52 | ✗ | std::chrono::steady_clock::now() - start) | |
| 53 | ✗ | .count(); | |
| 54 | ✗ | if (elapsed_ms > FLAGS_rdma_wait_timeout_ms) { | |
| 55 | ✗ | throw std::runtime_error("RC write RPC wait timeout"); | |
| 56 | } | ||
| 57 | } | ||
| 58 | } | ||
| 59 | ✗ | return status->status; | |
| 60 | } | ||
| 61 | |||
| 62 | ✗ | void FillBaseDescriptor( | |
| 63 | RequestDescriptor* descriptor, | ||
| 64 | std::uint64_t seq, | ||
| 65 | std::size_t key_count, | ||
| 66 | const RcClientQpView& view, | ||
| 67 | std::uint32_t shard_id, | ||
| 68 | std::uint32_t client_id) { | ||
| 69 | ✗ | *descriptor = RequestDescriptor{}; | |
| 70 | ✗ | descriptor->seq = seq; | |
| 71 | ✗ | descriptor->shard_id = shard_id; | |
| 72 | ✗ | descriptor->client_id = client_id; | |
| 73 | ✗ | descriptor->qp_index = static_cast<std::uint32_t>(view.qp_index); | |
| 74 | ✗ | descriptor->key_count = static_cast<std::uint32_t>(key_count); | |
| 75 | ✗ | descriptor->value_size = static_cast<std::uint32_t>(FLAGS_value_size); | |
| 76 | ✗ | descriptor->embedding_dim = | |
| 77 | ✗ | static_cast<std::uint32_t>(FLAGS_value_size / sizeof(float)); | |
| 78 | ✗ | descriptor->payload_offset = | |
| 79 | ✗ | static_cast<std::uint32_t>(Align64(sizeof(RequestDescriptor))); | |
| 80 | ✗ | descriptor->client_response_addr = | |
| 81 | ✗ | reinterpret_cast<std::uint64_t>(view.response_payload); | |
| 82 | ✗ | descriptor->client_status_addr = reinterpret_cast<std::uint64_t>(view.status); | |
| 83 | ✗ | } | |
| 84 | |||
| 85 | } // namespace | ||
| 86 | |||
| 87 | ✗ | PetPSClient::PetPSClient(const std::string& host, int port, int shard) | |
| 88 | ✗ | : PetPSClient(host, port, shard, -1) {} | |
| 89 | |||
| 90 | ✗ | PetPSClient::PetPSClient( | |
| 91 | ✗ | const std::string& host, int port, int shard, int logical_client_id) | |
| 92 | : BaseParameterClient(host, port, shard), | ||
| 93 | ✗ | namespace_token_(NamespaceToken()), | |
| 94 | ✗ | explicit_client_id_(logical_client_id) {} | |
| 95 | |||
| 96 | ✗ | PetPSClient::~PetPSClient() = default; | |
| 97 | |||
| 98 | ✗ | void PetPSClient::Barrier(const std::string&, int) {} | |
| 99 | |||
| 100 | ✗ | void PetPSClient::InitializeTransport() { | |
| 101 | ✗ | if (transport_ != nullptr) { | |
| 102 | ✗ | return; | |
| 103 | } | ||
| 104 | ✗ | client_id_ = | |
| 105 | ✗ | explicit_client_id_ >= 0 | |
| 106 | ✗ | ? explicit_client_id_ | |
| 107 | ✗ | : (FLAGS_rdma_rc_client_id_base >= 0 | |
| 108 | ✗ | ? FLAGS_rdma_rc_client_id_base | |
| 109 | ✗ | : FLAGS_global_id - FLAGS_num_server_processes); | |
| 110 | ✗ | if (client_id_ < 0) { | |
| 111 | ✗ | throw std::runtime_error("invalid RC write logical client_id"); | |
| 112 | } | ||
| 113 | ✗ | const int logical_num_clients = | |
| 114 | ✗ | FLAGS_rdma_rc_num_logical_clients >= 0 | |
| 115 | ✗ | ? FLAGS_rdma_rc_num_logical_clients | |
| 116 | : FLAGS_num_client_processes; | ||
| 117 | ✗ | if (client_id_ >= logical_num_clients) { | |
| 118 | ✗ | throw std::runtime_error("RC write logical client_id out of range"); | |
| 119 | } | ||
| 120 | ✗ | config_.shard_id = shard_; | |
| 121 | ✗ | config_.client_id = client_id_; | |
| 122 | ✗ | config_.num_clients = logical_num_clients; | |
| 123 | ✗ | config_.qps_per_client_per_shard = FLAGS_rdma_rc_qps_per_client_per_shard; | |
| 124 | ✗ | config_.slots_per_qp = FLAGS_rdma_rc_slots_per_qp; | |
| 125 | ✗ | config_.request_slot_bytes = | |
| 126 | ✗ | static_cast<std::size_t>(FLAGS_rdma_rc_request_slot_bytes); | |
| 127 | ✗ | config_.response_slot_bytes = | |
| 128 | ✗ | static_cast<std::size_t>(FLAGS_rdma_rc_response_slot_bytes); | |
| 129 | ✗ | config_.control_plane_host = FLAGS_rdma_control_plane_host; | |
| 130 | ✗ | config_.control_plane_port = FLAGS_rdma_control_plane_port; | |
| 131 | ✗ | config_.control_plane_timeout_ms = FLAGS_rdma_control_plane_timeout_ms; | |
| 132 | ✗ | config_.namespace_token = namespace_token_; | |
| 133 | |||
| 134 | ✗ | transport_ = std::make_unique<RcShardClientTransport>(config_); | |
| 135 | RdmaControlPlaneClient control_plane({ | ||
| 136 | ✗ | config_.control_plane_host, | |
| 137 | ✗ | config_.control_plane_port, | |
| 138 | ✗ | config_.control_plane_timeout_ms, | |
| 139 | ✗ | }); | |
| 140 | ✗ | control_plane.WaitServer(shard_, config_.control_plane_timeout_ms); | |
| 141 | ✗ | qps_.clear(); | |
| 142 | ✗ | qps_.reserve(static_cast<std::size_t>(config_.qps_per_client_per_shard)); | |
| 143 | ✗ | for (int qp = 0; qp < config_.qps_per_client_per_shard; ++qp) { | |
| 144 | ✗ | QpContext context; | |
| 145 | ✗ | context.qp_index = qp; | |
| 146 | ✗ | context.slots.reserve(static_cast<std::size_t>(config_.slots_per_qp)); | |
| 147 | ✗ | for (int slot_in_qp = 0; slot_in_qp < config_.slots_per_qp; ++slot_in_qp) { | |
| 148 | ✗ | context.slots.push_back( | |
| 149 | ✗ | SlotContext{transport_->OpenSlot(qp, slot_in_qp), 1, false}); | |
| 150 | } | ||
| 151 | ✗ | qps_.push_back(std::move(context)); | |
| 152 | ✗ | } | |
| 153 | ✗ | } | |
| 154 | |||
| 155 | ✗ | void PetPSClient::InitThread() { | |
| 156 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 157 | ✗ | InitializeTransport(); | |
| 158 | ✗ | thread_initialized_ = true; | |
| 159 | ✗ | } | |
| 160 | |||
| 161 | ✗ | std::size_t PetPSClient::ResponseBufferBytes(std::size_t key_count) const { | |
| 162 | ✗ | return GetResponseBytes( | |
| 163 | key_count, static_cast<std::size_t>(FLAGS_value_size)) + | ||
| 164 | ✗ | sizeof(std::int32_t); | |
| 165 | } | ||
| 166 | |||
| 167 | ✗ | void* PetPSClient::GetReceiveBuffer(size_t size) { | |
| 168 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 169 | ✗ | receive_buffers_.emplace_back(size, 0); | |
| 170 | ✗ | return receive_buffers_.back().data(); | |
| 171 | ✗ | } | |
| 172 | |||
| 173 | ✗ | const float* PetPSClient::BorrowGetResultPayload( | |
| 174 | int rpc_id, | ||
| 175 | std::size_t* key_count, | ||
| 176 | std::size_t* response_bytes, | ||
| 177 | std::int32_t* status_code) { | ||
| 178 | ✗ | PendingRpc pending; | |
| 179 | { | ||
| 180 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 181 | ✗ | if (!PendingRpcLocked(rpc_id, &pending)) { | |
| 182 | ✗ | return nullptr; | |
| 183 | } | ||
| 184 | ✗ | } | |
| 185 | |||
| 186 | ✗ | auto& slot = SlotAt(pending.qp_index, pending.slot_in_qp); | |
| 187 | ✗ | const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0; | |
| 188 | ✗ | const std::uint64_t wait_start_ns = profile_enabled ? NowNs() : 0; | |
| 189 | ✗ | const std::int32_t rc_status = WaitStatus(slot.view.status, pending.seq); | |
| 190 | ✗ | if (profile_enabled) { | |
| 191 | ✗ | profile_.wait_rpc_count.fetch_add(1, std::memory_order_relaxed); | |
| 192 | ✗ | profile_.wait_status_ns.fetch_add( | |
| 193 | ✗ | NowNs() - wait_start_ns, std::memory_order_relaxed); | |
| 194 | } | ||
| 195 | |||
| 196 | ✗ | const std::size_t actual_response_bytes = std::min<std::size_t>( | |
| 197 | ✗ | slot.view.status->response_bytes, pending.response_bytes); | |
| 198 | ✗ | if (key_count != nullptr) { | |
| 199 | ✗ | *key_count = pending.key_count; | |
| 200 | } | ||
| 201 | ✗ | if (response_bytes != nullptr) { | |
| 202 | ✗ | *response_bytes = actual_response_bytes; | |
| 203 | } | ||
| 204 | ✗ | if (status_code != nullptr) { | |
| 205 | ✗ | *status_code = rc_status; | |
| 206 | } | ||
| 207 | ✗ | if (pending.recv_buffer != nullptr) { | |
| 208 | ✗ | auto* user_status = FixedSlotStatusWord( | |
| 209 | ✗ | pending.recv_buffer, pending.key_count, FLAGS_value_size); | |
| 210 | ✗ | *user_status = rc_status; | |
| 211 | } | ||
| 212 | ✗ | MaybeReportProfile(); | |
| 213 | ✗ | return reinterpret_cast<const float*>(slot.view.response_payload); | |
| 214 | } | ||
| 215 | |||
| 216 | ✗ | PetPSClient::SlotHandle PetPSClient::AcquireIdleSlot() { | |
| 217 | ✗ | if (FLAGS_rdma_rc_profile_interval_ms > 0) { | |
| 218 | ✗ | profile_.acquire_qp_count.fetch_add(1, std::memory_order_relaxed); | |
| 219 | } | ||
| 220 | ✗ | for (std::size_t qp_index = 0; qp_index < qps_.size(); ++qp_index) { | |
| 221 | ✗ | auto& qp = qps_[qp_index]; | |
| 222 | ✗ | for (std::size_t slot_in_qp = 0; slot_in_qp < qp.slots.size(); | |
| 223 | ++slot_in_qp) { | ||
| 224 | ✗ | if (!qp.slots[slot_in_qp].busy) { | |
| 225 | ✗ | qp.slots[slot_in_qp].busy = true; | |
| 226 | return SlotHandle{ | ||
| 227 | static_cast<int>(qp_index), | ||
| 228 | static_cast<int>(slot_in_qp), | ||
| 229 | ✗ | }; | |
| 230 | } | ||
| 231 | } | ||
| 232 | } | ||
| 233 | ✗ | if (FLAGS_rdma_rc_profile_interval_ms > 0) { | |
| 234 | ✗ | profile_.acquire_qp_failures.fetch_add(1, std::memory_order_relaxed); | |
| 235 | } | ||
| 236 | ✗ | throw std::runtime_error("no idle RC write slot available"); | |
| 237 | } | ||
| 238 | |||
| 239 | ✗ | PetPSClient::SlotContext& PetPSClient::SlotAt(int qp_index, int slot_in_qp) { | |
| 240 | ✗ | auto& qp = qps_.at(static_cast<std::size_t>(qp_index)); | |
| 241 | ✗ | return qp.slots.at(static_cast<std::size_t>(slot_in_qp)); | |
| 242 | } | ||
| 243 | |||
| 244 | const PetPSClient::SlotContext& | ||
| 245 | ✗ | PetPSClient::SlotAt(int qp_index, int slot_in_qp) const { | |
| 246 | ✗ | const auto& qp = qps_.at(static_cast<std::size_t>(qp_index)); | |
| 247 | ✗ | return qp.slots.at(static_cast<std::size_t>(slot_in_qp)); | |
| 248 | } | ||
| 249 | |||
| 250 | ✗ | void PetPSClient::EnsureThreadInitializedLocked() const { | |
| 251 | ✗ | if (!thread_initialized_) { | |
| 252 | ✗ | throw std::runtime_error("PetPSClient::InitThread must be called first"); | |
| 253 | } | ||
| 254 | ✗ | } | |
| 255 | |||
| 256 | ✗ | bool PetPSClient::PendingRpcLocked(int rpc_id, PendingRpc* pending) const { | |
| 257 | ✗ | const auto it = pending_rpcs_.find(rpc_id); | |
| 258 | ✗ | if (it == pending_rpcs_.end()) { | |
| 259 | ✗ | return false; | |
| 260 | } | ||
| 261 | ✗ | if (pending != nullptr) { | |
| 262 | ✗ | *pending = it->second; | |
| 263 | } | ||
| 264 | ✗ | return true; | |
| 265 | } | ||
| 266 | |||
| 267 | ✗ | bool PetPSClient::RequestPayloadFitsSlot(std::size_t payload_bytes) const { | |
| 268 | ✗ | return Align64(sizeof(RequestDescriptor)) + payload_bytes + | |
| 269 | ✗ | Align64(sizeof(CommitWord)) <= | |
| 270 | ✗ | config_.request_slot_bytes; | |
| 271 | } | ||
| 272 | |||
| 273 | ✗ | float* PetPSClient::AllocateStatusReceiveBufferLocked() { | |
| 274 | ✗ | receive_buffers_.emplace_back(sizeof(std::int32_t), 0); | |
| 275 | ✗ | return reinterpret_cast<float*>(receive_buffers_.back().data()); | |
| 276 | } | ||
| 277 | |||
| 278 | ✗ | void PetPSClient::MaybeReportProfile() { | |
| 279 | ✗ | if (FLAGS_rdma_rc_profile_interval_ms <= 0) { | |
| 280 | ✗ | return; | |
| 281 | } | ||
| 282 | ✗ | const std::uint64_t now = NowNs(); | |
| 283 | ✗ | const std::uint64_t interval = | |
| 284 | ✗ | static_cast<std::uint64_t>(FLAGS_rdma_rc_profile_interval_ms) * 1000000; | |
| 285 | std::uint64_t expected = | ||
| 286 | ✗ | profile_.next_report_ns.load(std::memory_order_relaxed); | |
| 287 | ✗ | if (expected == 0) { | |
| 288 | ✗ | profile_.next_report_ns.compare_exchange_strong( | |
| 289 | expected, now + interval, std::memory_order_relaxed); | ||
| 290 | ✗ | return; | |
| 291 | } | ||
| 292 | ✗ | if (now < expected || | |
| 293 | ✗ | !profile_.next_report_ns.compare_exchange_strong( | |
| 294 | expected, now + interval, std::memory_order_relaxed)) { | ||
| 295 | ✗ | return; | |
| 296 | } | ||
| 297 | |||
| 298 | ✗ | const std::uint64_t submit_count = Exchange(&profile_.submit_rpc_count); | |
| 299 | ✗ | const std::uint64_t wait_count = Exchange(&profile_.wait_rpc_count); | |
| 300 | ✗ | const std::uint64_t revoke_count = Exchange(&profile_.revoke_rpc_count); | |
| 301 | ✗ | const std::uint64_t submit_ns = Exchange(&profile_.submit_request_ns); | |
| 302 | ✗ | const std::uint64_t wait_ns = Exchange(&profile_.wait_status_ns); | |
| 303 | ✗ | const std::uint64_t copy_ns = Exchange(&profile_.copy_response_ns); | |
| 304 | ✗ | const std::uint64_t revoke_ns = Exchange(&profile_.revoke_resource_ns); | |
| 305 | ✗ | const std::uint64_t pending_samples = Exchange(&profile_.pending_rpc_samples); | |
| 306 | ✗ | const std::uint64_t pending_sum = Exchange(&profile_.pending_rpc_sum); | |
| 307 | std::cout | ||
| 308 | << "component=rdma_rc_client_profile" | ||
| 309 | ✗ | << " shard=" << shard_ << " client_id=" << client_id_ | |
| 310 | ✗ | << " submit_count=" << submit_count << " wait_count=" << wait_count | |
| 311 | ✗ | << " revoke_count=" << revoke_count | |
| 312 | ✗ | << " acquire_qp_count=" << Exchange(&profile_.acquire_qp_count) | |
| 313 | ✗ | << " acquire_qp_failures=" << Exchange(&profile_.acquire_qp_failures) | |
| 314 | ✗ | << " submit_avg_ns=" << (submit_count == 0 ? 0 : submit_ns / submit_count) | |
| 315 | ✗ | << " wait_status_avg_ns=" << (wait_count == 0 ? 0 : wait_ns / wait_count) | |
| 316 | ✗ | << " copy_response_avg_ns=" | |
| 317 | ✗ | << (wait_count == 0 ? 0 : copy_ns / wait_count) | |
| 318 | ✗ | << " copied_bytes=" << Exchange(&profile_.response_bytes_copied) | |
| 319 | ✗ | << " revoke_avg_ns=" << (revoke_count == 0 ? 0 : revoke_ns / revoke_count) | |
| 320 | ✗ | << " pending_rpc_peak=" << Exchange(&profile_.pending_rpc_peak) | |
| 321 | ✗ | << " pending_rpc_avg=" | |
| 322 | ✗ | << (pending_samples == 0 ? 0 : pending_sum / pending_samples) | |
| 323 | ✗ | << " pending_rpc_last=" | |
| 324 | ✗ | << profile_.pending_rpc_last.load(std::memory_order_relaxed) << std::endl; | |
| 325 | } | ||
| 326 | |||
| 327 | ✗ | void PetPSClient::FillGetDescriptor( | |
| 328 | RequestDescriptor* descriptor, | ||
| 329 | std::uint64_t seq, | ||
| 330 | std::size_t key_count, | ||
| 331 | std::size_t response_bytes, | ||
| 332 | const RcClientQpView& view) const { | ||
| 333 | ✗ | FillBaseDescriptor( | |
| 334 | descriptor, | ||
| 335 | seq, | ||
| 336 | key_count, | ||
| 337 | view, | ||
| 338 | ✗ | static_cast<std::uint32_t>(shard_), | |
| 339 | ✗ | static_cast<std::uint32_t>(client_id_)); | |
| 340 | ✗ | descriptor->op = static_cast<std::uint16_t>(RcOp::kGet); | |
| 341 | ✗ | descriptor->payload_bytes = | |
| 342 | ✗ | static_cast<std::uint32_t>(GetRequestBytes(key_count)); | |
| 343 | ✗ | descriptor->response_bytes = static_cast<std::uint32_t>(response_bytes); | |
| 344 | ✗ | if (FLAGS_rdma_get_response_mode == "direct_sg") { | |
| 345 | ✗ | descriptor->flags |= kRcFlagGetDirectSg | kRcFlagGetAllowFallbackCopy; | |
| 346 | ✗ | } else if (FLAGS_rdma_get_response_mode != "staging_copy") { | |
| 347 | ✗ | LOG(FATAL) << "unsupported --rdma_get_response_mode=" | |
| 348 | ✗ | << FLAGS_rdma_get_response_mode; | |
| 349 | } | ||
| 350 | ✗ | } | |
| 351 | |||
| 352 | ✗ | void PetPSClient::FillPutDescriptor( | |
| 353 | RequestDescriptor* descriptor, | ||
| 354 | std::uint64_t seq, | ||
| 355 | std::size_t key_count, | ||
| 356 | std::size_t payload_bytes, | ||
| 357 | const RcClientQpView& view) const { | ||
| 358 | ✗ | FillBaseDescriptor( | |
| 359 | descriptor, | ||
| 360 | seq, | ||
| 361 | key_count, | ||
| 362 | view, | ||
| 363 | ✗ | static_cast<std::uint32_t>(shard_), | |
| 364 | ✗ | static_cast<std::uint32_t>(client_id_)); | |
| 365 | ✗ | descriptor->op = static_cast<std::uint16_t>(RcOp::kPut); | |
| 366 | ✗ | descriptor->payload_bytes = static_cast<std::uint32_t>(payload_bytes); | |
| 367 | ✗ | descriptor->response_bytes = 0; | |
| 368 | ✗ | } | |
| 369 | |||
| 370 | ✗ | void PetPSClient::FillUpdateDescriptor( | |
| 371 | RequestDescriptor* descriptor, | ||
| 372 | std::uint64_t seq, | ||
| 373 | std::size_t key_count, | ||
| 374 | std::size_t payload_bytes, | ||
| 375 | const std::string& table_name, | ||
| 376 | const RcClientQpView& view) const { | ||
| 377 | ✗ | FillPutDescriptor(descriptor, seq, key_count, payload_bytes, view); | |
| 378 | ✗ | descriptor->op = static_cast<std::uint16_t>(RcOp::kUpdate); | |
| 379 | ✗ | if (!CopyTableName(table_name, &descriptor->table_name)) { | |
| 380 | ✗ | throw std::runtime_error("UPDATE table name too long"); | |
| 381 | } | ||
| 382 | ✗ | } | |
| 383 | |||
| 384 | ✗ | void PetPSClient::FillInitTableDescriptor( | |
| 385 | RequestDescriptor* descriptor, | ||
| 386 | std::uint64_t seq, | ||
| 387 | const std::string& table_name, | ||
| 388 | const RcClientQpView& view) const { | ||
| 389 | ✗ | FillPutDescriptor( | |
| 390 | descriptor, seq, /*key_count=*/0, InitTablePayloadBytes(), view); | ||
| 391 | ✗ | descriptor->op = static_cast<std::uint16_t>(RcOp::kInitTable); | |
| 392 | ✗ | if (!CopyTableName(table_name, &descriptor->table_name)) { | |
| 393 | ✗ | throw std::runtime_error("INIT table name too long"); | |
| 394 | } | ||
| 395 | ✗ | } | |
| 396 | |||
| 397 | ✗ | int PetPSClient::SubmitRpcLocked( | |
| 398 | SlotContext* slot, | ||
| 399 | const RequestDescriptor& descriptor, | ||
| 400 | const void* payload, | ||
| 401 | std::size_t payload_bytes, | ||
| 402 | float* recv_buffer, | ||
| 403 | std::size_t key_count, | ||
| 404 | std::size_t response_bytes, | ||
| 405 | bool is_async) { | ||
| 406 | ✗ | if (slot == nullptr) { | |
| 407 | ✗ | throw std::runtime_error("slot context is null"); | |
| 408 | } | ||
| 409 | ✗ | ResetStatusWord(slot->view.status, descriptor.seq); | |
| 410 | ✗ | const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0; | |
| 411 | ✗ | const std::uint64_t submit_start_ns = profile_enabled ? NowNs() : 0; | |
| 412 | ✗ | transport_->SubmitRequest(slot->view, descriptor, payload, payload_bytes); | |
| 413 | ✗ | if (profile_enabled) { | |
| 414 | ✗ | profile_.submit_rpc_count.fetch_add(1, std::memory_order_relaxed); | |
| 415 | ✗ | profile_.submit_request_ns.fetch_add( | |
| 416 | ✗ | NowNs() - submit_start_ns, std::memory_order_relaxed); | |
| 417 | } | ||
| 418 | ✗ | VLOG(1) << "component=rdma_rc_client event=submit shard=" << shard_ | |
| 419 | ✗ | << " client_id=" << client_id_ << " qp=" << slot->view.qp_index | |
| 420 | ✗ | << " slot=" << slot->view.slot_index << " seq=" << descriptor.seq | |
| 421 | ✗ | << " op=" << descriptor.op << " key_count=" << key_count | |
| 422 | ✗ | << " payload_bytes=" << payload_bytes | |
| 423 | ✗ | << " response_bytes=" << response_bytes; | |
| 424 | |||
| 425 | ✗ | const int rpc_id = next_rpc_id_.fetch_add(1); | |
| 426 | ✗ | pending_rpcs_.emplace( | |
| 427 | rpc_id, | ||
| 428 | ✗ | PendingRpc{ | |
| 429 | ✗ | slot->view.qp_index, | |
| 430 | ✗ | slot->view.slot_in_qp, | |
| 431 | ✗ | slot->view.slot_index, | |
| 432 | ✗ | descriptor.seq, | |
| 433 | recv_buffer, | ||
| 434 | key_count, | ||
| 435 | response_bytes, | ||
| 436 | }); | ||
| 437 | ✗ | if (profile_enabled) { | |
| 438 | ✗ | const std::uint64_t pending_size = pending_rpcs_.size(); | |
| 439 | ✗ | profile_.pending_rpc_samples.fetch_add(1, std::memory_order_relaxed); | |
| 440 | ✗ | profile_.pending_rpc_sum.fetch_add(pending_size, std::memory_order_relaxed); | |
| 441 | ✗ | profile_.pending_rpc_last.store(pending_size, std::memory_order_relaxed); | |
| 442 | std::uint64_t peak = | ||
| 443 | ✗ | profile_.pending_rpc_peak.load(std::memory_order_relaxed); | |
| 444 | ✗ | while (pending_size > peak && | |
| 445 | ✗ | !profile_.pending_rpc_peak.compare_exchange_weak( | |
| 446 | peak, pending_size, std::memory_order_relaxed)) { | ||
| 447 | } | ||
| 448 | ✗ | MaybeReportProfile(); | |
| 449 | } | ||
| 450 | ✗ | if (!is_async) { | |
| 451 | ✗ | WaitRPCFinish(rpc_id); | |
| 452 | } | ||
| 453 | ✗ | return rpc_id; | |
| 454 | } | ||
| 455 | |||
| 456 | ✗ | int PetPSClient::GetParameter(base::ConstArray<uint64_t> keys, | |
| 457 | std::vector<std::vector<float>>* values) { | ||
| 458 | ✗ | values->clear(); | |
| 459 | ✗ | if (keys.Size() == 0) { | |
| 460 | ✗ | return 0; | |
| 461 | } | ||
| 462 | ✗ | const int embedding_dim = FLAGS_value_size / sizeof(float); | |
| 463 | ✗ | std::vector<float> flat(keys.Size() * embedding_dim + 1, 0.0f); | |
| 464 | ✗ | const int rpc_id = GetParameter(keys, flat.data(), false, 0); | |
| 465 | const auto* status = | ||
| 466 | ✗ | FixedSlotStatusWord(flat.data(), keys.Size(), FLAGS_value_size); | |
| 467 | ✗ | if (*status != static_cast<std::int32_t>(RpcStatus::kOk)) { | |
| 468 | ✗ | RevokeRPCResource(rpc_id); | |
| 469 | ✗ | return -1; | |
| 470 | } | ||
| 471 | ✗ | CopyFlatRowsToVectors( | |
| 472 | ✗ | flat.data(), | |
| 473 | ✗ | keys.Size(), | |
| 474 | static_cast<std::size_t>(embedding_dim), | ||
| 475 | values); | ||
| 476 | ✗ | RevokeRPCResource(rpc_id); | |
| 477 | ✗ | return 0; | |
| 478 | ✗ | } | |
| 479 | |||
| 480 | ✗ | int PetPSClient::GetParameter( | |
| 481 | base::ConstArray<uint64_t> keys, float* values, bool isAsync, int) { | ||
| 482 | ✗ | if (keys.Size() == 0) { | |
| 483 | ✗ | auto* status = | |
| 484 | reinterpret_cast<std::int32_t*>(reinterpret_cast<char*>(values)); | ||
| 485 | ✗ | *status = static_cast<std::int32_t>(RpcStatus::kOk); | |
| 486 | ✗ | return 0; | |
| 487 | } | ||
| 488 | ✗ | int rpc_id = 0; | |
| 489 | { | ||
| 490 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 491 | ✗ | EnsureThreadInitializedLocked(); | |
| 492 | ✗ | if (keys.Size() > ComputeMaxGetKeysPerRpc()) { | |
| 493 | ✗ | throw std::runtime_error( | |
| 494 | ✗ | "single-shard GET batch exceeds RC response budget"); | |
| 495 | } | ||
| 496 | |||
| 497 | ✗ | const SlotHandle slot_handle = AcquireIdleSlot(); | |
| 498 | ✗ | auto& slot = SlotAt(slot_handle.qp_index, slot_handle.slot_in_qp); | |
| 499 | ✗ | RequestDescriptor descriptor; | |
| 500 | ✗ | const std::size_t response_bytes = GetResponseBytes( | |
| 501 | ✗ | keys.Size(), static_cast<std::size_t>(FLAGS_value_size)); | |
| 502 | ✗ | FillGetDescriptor( | |
| 503 | ✗ | &descriptor, slot.next_seq++, keys.Size(), response_bytes, slot.view); | |
| 504 | ✗ | if (descriptor.payload_bytes > | |
| 505 | ✗ | PutPayloadBudget(config_.request_slot_bytes)) { | |
| 506 | ✗ | slot.busy = false; | |
| 507 | ✗ | throw std::runtime_error("GET request exceeds RC request slot"); | |
| 508 | } | ||
| 509 | ✗ | rpc_id = SubmitRpcLocked( | |
| 510 | &slot, | ||
| 511 | descriptor, | ||
| 512 | ✗ | keys.Data(), | |
| 513 | ✗ | descriptor.payload_bytes, | |
| 514 | values, | ||
| 515 | ✗ | keys.Size(), | |
| 516 | response_bytes, | ||
| 517 | true); | ||
| 518 | ✗ | } | |
| 519 | ✗ | if (!isAsync) { | |
| 520 | ✗ | WaitRPCFinish(rpc_id); | |
| 521 | } | ||
| 522 | ✗ | return rpc_id; | |
| 523 | } | ||
| 524 | |||
| 525 | ✗ | bool PetPSClient::QueryRPCFinished(int rpc_id) { | |
| 526 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 527 | ✗ | PendingRpc pending; | |
| 528 | ✗ | if (!PendingRpcLocked(rpc_id, &pending)) { | |
| 529 | ✗ | return true; | |
| 530 | } | ||
| 531 | ✗ | const auto& slot = SlotAt(pending.qp_index, pending.slot_in_qp); | |
| 532 | ✗ | return StatusWordDone(*slot.view.status, pending.seq); | |
| 533 | ✗ | } | |
| 534 | |||
| 535 | ✗ | void PetPSClient::WaitRPCFinish(int rpc_id) { | |
| 536 | ✗ | PendingRpc pending; | |
| 537 | { | ||
| 538 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 539 | ✗ | if (!PendingRpcLocked(rpc_id, &pending)) { | |
| 540 | ✗ | return; | |
| 541 | } | ||
| 542 | ✗ | } | |
| 543 | |||
| 544 | ✗ | auto& slot = SlotAt(pending.qp_index, pending.slot_in_qp); | |
| 545 | ✗ | const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0; | |
| 546 | ✗ | const std::uint64_t wait_start_ns = profile_enabled ? NowNs() : 0; | |
| 547 | ✗ | const std::int32_t status_code = WaitStatus(slot.view.status, pending.seq); | |
| 548 | ✗ | if (profile_enabled) { | |
| 549 | ✗ | profile_.wait_rpc_count.fetch_add(1, std::memory_order_relaxed); | |
| 550 | ✗ | profile_.wait_status_ns.fetch_add( | |
| 551 | ✗ | NowNs() - wait_start_ns, std::memory_order_relaxed); | |
| 552 | } | ||
| 553 | ✗ | VLOG(1) << "component=rdma_rc_client event=done shard=" << shard_ | |
| 554 | ✗ | << " client_id=" << client_id_ << " qp=" << pending.qp_index | |
| 555 | ✗ | << " slot=" << pending.slot_index << " seq=" << pending.seq | |
| 556 | ✗ | << " status=" << status_code | |
| 557 | ✗ | << " response_bytes=" << pending.response_bytes; | |
| 558 | ✗ | const std::size_t actual_response_bytes = std::min<std::size_t>( | |
| 559 | ✗ | slot.view.status->response_bytes, pending.response_bytes); | |
| 560 | ✗ | if (actual_response_bytes > 0 && !FLAGS_rdma_rc_skip_client_copy) { | |
| 561 | ✗ | const std::uint64_t copy_start_ns = profile_enabled ? NowNs() : 0; | |
| 562 | ✗ | std::memcpy( | |
| 563 | ✗ | pending.recv_buffer, slot.view.response_payload, actual_response_bytes); | |
| 564 | ✗ | if (profile_enabled) { | |
| 565 | ✗ | profile_.copy_response_ns.fetch_add( | |
| 566 | ✗ | NowNs() - copy_start_ns, std::memory_order_relaxed); | |
| 567 | ✗ | profile_.response_bytes_copied.fetch_add( | |
| 568 | actual_response_bytes, std::memory_order_relaxed); | ||
| 569 | } | ||
| 570 | } | ||
| 571 | ✗ | auto* user_status = FixedSlotStatusWord( | |
| 572 | ✗ | pending.recv_buffer, pending.key_count, FLAGS_value_size); | |
| 573 | ✗ | *user_status = status_code; | |
| 574 | ✗ | MaybeReportProfile(); | |
| 575 | } | ||
| 576 | |||
| 577 | ✗ | void PetPSClient::RevokeRPCResource(int rpc_id) { | |
| 578 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 579 | ✗ | const auto it = pending_rpcs_.find(rpc_id); | |
| 580 | ✗ | if (it == pending_rpcs_.end()) { | |
| 581 | ✗ | return; | |
| 582 | } | ||
| 583 | ✗ | const bool profile_enabled = FLAGS_rdma_rc_profile_interval_ms > 0; | |
| 584 | ✗ | const std::uint64_t revoke_start_ns = profile_enabled ? NowNs() : 0; | |
| 585 | ✗ | auto& slot = SlotAt(it->second.qp_index, it->second.slot_in_qp); | |
| 586 | ✗ | transport_->ClearRequestSlot(slot.view); | |
| 587 | ✗ | slot.busy = false; | |
| 588 | ✗ | pending_rpcs_.erase(it); | |
| 589 | ✗ | if (profile_enabled) { | |
| 590 | ✗ | const std::uint64_t pending_size = pending_rpcs_.size(); | |
| 591 | ✗ | profile_.pending_rpc_samples.fetch_add(1, std::memory_order_relaxed); | |
| 592 | ✗ | profile_.pending_rpc_sum.fetch_add(pending_size, std::memory_order_relaxed); | |
| 593 | ✗ | profile_.pending_rpc_last.store(pending_size, std::memory_order_relaxed); | |
| 594 | ✗ | profile_.revoke_rpc_count.fetch_add(1, std::memory_order_relaxed); | |
| 595 | ✗ | profile_.revoke_resource_ns.fetch_add( | |
| 596 | ✗ | NowNs() - revoke_start_ns, std::memory_order_relaxed); | |
| 597 | ✗ | MaybeReportProfile(); | |
| 598 | } | ||
| 599 | ✗ | } | |
| 600 | |||
| 601 | ✗ | int PetPSClient::PutParameter(const std::vector<uint64_t>& keys, | |
| 602 | const std::vector<std::vector<float>>& values) { | ||
| 603 | ✗ | if (keys.size() != values.size()) { | |
| 604 | ✗ | return -1; | |
| 605 | } | ||
| 606 | ✗ | if (keys.empty()) { | |
| 607 | ✗ | return 0; | |
| 608 | } | ||
| 609 | |||
| 610 | ✗ | std::size_t begin = 0; | |
| 611 | ✗ | while (begin < keys.size()) { | |
| 612 | std::size_t end = | ||
| 613 | ✗ | std::min(begin + static_cast<std::size_t>(FLAGS_max_kv_num_per_request), | |
| 614 | ✗ | keys.size()); | |
| 615 | std::vector<std::uint64_t> key_slice( | ||
| 616 | ✗ | keys.begin() + begin, keys.begin() + end); | |
| 617 | std::vector<std::vector<float>> value_slice( | ||
| 618 | ✗ | values.begin() + begin, values.begin() + end); | |
| 619 | |||
| 620 | ✗ | std::string payload; | |
| 621 | ✗ | std::string error; | |
| 622 | const std::size_t payload_bytes = | ||
| 623 | ✗ | PutPayloadBytes(key_slice, value_slice, &payload, &error); | |
| 624 | ✗ | if (payload_bytes == 0 && !key_slice.empty()) { | |
| 625 | ✗ | throw std::runtime_error("RC PUT payload build failed: " + error); | |
| 626 | } | ||
| 627 | |||
| 628 | ✗ | float* recv = nullptr; | |
| 629 | ✗ | int rpc_id = 0; | |
| 630 | { | ||
| 631 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 632 | ✗ | EnsureThreadInitializedLocked(); | |
| 633 | ✗ | const SlotHandle slot_handle = AcquireIdleSlot(); | |
| 634 | ✗ | auto& slot = SlotAt(slot_handle.qp_index, slot_handle.slot_in_qp); | |
| 635 | ✗ | RequestDescriptor descriptor; | |
| 636 | ✗ | FillPutDescriptor( | |
| 637 | &descriptor, | ||
| 638 | ✗ | slot.next_seq++, | |
| 639 | key_slice.size(), | ||
| 640 | payload_bytes, | ||
| 641 | ✗ | slot.view); | |
| 642 | ✗ | if (!RequestPayloadFitsSlot(payload_bytes)) { | |
| 643 | ✗ | slot.busy = false; | |
| 644 | ✗ | throw std::runtime_error("PUT request exceeds RC request slot"); | |
| 645 | } | ||
| 646 | ✗ | recv = AllocateStatusReceiveBufferLocked(); | |
| 647 | ✗ | rpc_id = SubmitRpcLocked( | |
| 648 | ✗ | &slot, descriptor, payload.data(), payload_bytes, recv, 0, 0, true); | |
| 649 | ✗ | } | |
| 650 | ✗ | WaitRPCFinish(rpc_id); | |
| 651 | ✗ | const auto* status = reinterpret_cast<const std::int32_t*>(recv); | |
| 652 | ✗ | RevokeRPCResource(rpc_id); | |
| 653 | ✗ | if (*status != static_cast<std::int32_t>(RpcStatus::kOk)) { | |
| 654 | ✗ | return -1; | |
| 655 | } | ||
| 656 | ✗ | begin = end; | |
| 657 | ✗ | } | |
| 658 | |||
| 659 | ✗ | return 0; | |
| 660 | } | ||
| 661 | |||
| 662 | ✗ | int PetPSClient::InitEmbeddingTable(const std::string& table_name, | |
| 663 | std::uint64_t num_embeddings, | ||
| 664 | std::uint64_t embedding_dim) { | ||
| 665 | const std::array<std::uint64_t, 2> payload_words = { | ||
| 666 | num_embeddings, | ||
| 667 | embedding_dim, | ||
| 668 | ✗ | }; | |
| 669 | |||
| 670 | ✗ | float* recv = nullptr; | |
| 671 | ✗ | int rpc_id = 0; | |
| 672 | { | ||
| 673 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 674 | ✗ | EnsureThreadInitializedLocked(); | |
| 675 | ✗ | const SlotHandle slot_handle = AcquireIdleSlot(); | |
| 676 | ✗ | auto& slot = SlotAt(slot_handle.qp_index, slot_handle.slot_in_qp); | |
| 677 | ✗ | RequestDescriptor descriptor; | |
| 678 | ✗ | FillInitTableDescriptor( | |
| 679 | ✗ | &descriptor, slot.next_seq++, table_name, slot.view); | |
| 680 | ✗ | if (!RequestPayloadFitsSlot(descriptor.payload_bytes)) { | |
| 681 | ✗ | slot.busy = false; | |
| 682 | ✗ | throw std::runtime_error("INIT request exceeds RC request slot"); | |
| 683 | } | ||
| 684 | ✗ | recv = AllocateStatusReceiveBufferLocked(); | |
| 685 | ✗ | rpc_id = SubmitRpcLocked( | |
| 686 | &slot, | ||
| 687 | descriptor, | ||
| 688 | ✗ | payload_words.data(), | |
| 689 | ✗ | descriptor.payload_bytes, | |
| 690 | recv, | ||
| 691 | 0, | ||
| 692 | 0, | ||
| 693 | true); | ||
| 694 | ✗ | } | |
| 695 | |||
| 696 | ✗ | WaitRPCFinish(rpc_id); | |
| 697 | ✗ | const auto* status = reinterpret_cast<const std::int32_t*>(recv); | |
| 698 | ✗ | RevokeRPCResource(rpc_id); | |
| 699 | ✗ | return (*status == static_cast<std::int32_t>(RpcStatus::kOk)) ? 0 : -1; | |
| 700 | } | ||
| 701 | |||
| 702 | ✗ | int PetPSClient::UpdateParameter(const std::string& table_name, | |
| 703 | base::ConstArray<uint64_t> keys, | ||
| 704 | const std::vector<std::vector<float>>* grads) { | ||
| 705 | ✗ | if (keys.Size() == 0) { | |
| 706 | ✗ | return 0; | |
| 707 | } | ||
| 708 | ✗ | if (grads == nullptr) { | |
| 709 | ✗ | return -1; | |
| 710 | } | ||
| 711 | ✗ | if (keys.Size() != grads->size()) { | |
| 712 | ✗ | return -1; | |
| 713 | } | ||
| 714 | |||
| 715 | ✗ | std::size_t begin = 0; | |
| 716 | ✗ | const std::size_t total_keys = static_cast<std::size_t>(keys.Size()); | |
| 717 | ✗ | while (begin < total_keys) { | |
| 718 | const std::size_t end = | ||
| 719 | ✗ | std::min(begin + static_cast<std::size_t>(FLAGS_max_kv_num_per_request), | |
| 720 | ✗ | total_keys); | |
| 721 | std::vector<std::uint64_t> key_slice( | ||
| 722 | ✗ | keys.Data() + begin, keys.Data() + end); | |
| 723 | std::vector<std::vector<float>> grad_slice( | ||
| 724 | ✗ | grads->begin() + begin, grads->begin() + end); | |
| 725 | |||
| 726 | ✗ | std::string payload; | |
| 727 | ✗ | std::string error; | |
| 728 | const std::size_t payload_bytes = | ||
| 729 | ✗ | UpdatePayloadBytes(key_slice, grad_slice, &payload, &error); | |
| 730 | ✗ | if (payload_bytes == 0 && !key_slice.empty()) { | |
| 731 | ✗ | throw std::runtime_error("RC UPDATE payload build failed: " + error); | |
| 732 | } | ||
| 733 | |||
| 734 | ✗ | float* recv = nullptr; | |
| 735 | ✗ | int rpc_id = 0; | |
| 736 | { | ||
| 737 | ✗ | std::lock_guard<std::mutex> guard(mu_); | |
| 738 | ✗ | EnsureThreadInitializedLocked(); | |
| 739 | |||
| 740 | ✗ | const SlotHandle slot_handle = AcquireIdleSlot(); | |
| 741 | ✗ | auto& slot = SlotAt(slot_handle.qp_index, slot_handle.slot_in_qp); | |
| 742 | ✗ | RequestDescriptor descriptor; | |
| 743 | ✗ | FillUpdateDescriptor( | |
| 744 | &descriptor, | ||
| 745 | ✗ | slot.next_seq++, | |
| 746 | key_slice.size(), | ||
| 747 | payload_bytes, | ||
| 748 | table_name, | ||
| 749 | ✗ | slot.view); | |
| 750 | ✗ | if (!RequestPayloadFitsSlot(payload_bytes)) { | |
| 751 | ✗ | slot.busy = false; | |
| 752 | ✗ | throw std::runtime_error("UPDATE request exceeds RC request slot"); | |
| 753 | } | ||
| 754 | ✗ | recv = AllocateStatusReceiveBufferLocked(); | |
| 755 | ✗ | rpc_id = SubmitRpcLocked( | |
| 756 | ✗ | &slot, descriptor, payload.data(), payload_bytes, recv, 0, 0, true); | |
| 757 | ✗ | } | |
| 758 | |||
| 759 | ✗ | WaitRPCFinish(rpc_id); | |
| 760 | ✗ | const auto* status = reinterpret_cast<const std::int32_t*>(recv); | |
| 761 | ✗ | RevokeRPCResource(rpc_id); | |
| 762 | ✗ | if (*status != static_cast<std::int32_t>(RpcStatus::kOk)) { | |
| 763 | ✗ | return -1; | |
| 764 | } | ||
| 765 | ✗ | begin = end; | |
| 766 | ✗ | } | |
| 767 | |||
| 768 | ✗ | return 0; | |
| 769 | } | ||
| 770 | |||
| 771 | ✗ | int PetPSClient::FakePutParameter(base::ConstArray<uint64_t> keys, | |
| 772 | float* values) { | ||
| 773 | ✗ | const int embedding_dim = FLAGS_value_size / sizeof(float); | |
| 774 | ✗ | std::vector<std::vector<float>> rows; | |
| 775 | ✗ | rows.reserve(keys.Size()); | |
| 776 | ✗ | for (int i = 0; i < keys.Size(); ++i) { | |
| 777 | ✗ | rows.emplace_back( | |
| 778 | ✗ | values + i * embedding_dim, values + (i + 1) * embedding_dim); | |
| 779 | } | ||
| 780 | ✗ | return PutParameter(keys.ToVector(), rows); | |
| 781 | ✗ | } | |
| 782 | |||
| 783 | } // namespace petps | ||
| 784 |