ps/rdma/raw_verbs_transport.cc
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "ps/rdma/raw_verbs_transport.h" | ||
| 2 | |||
| 3 | #include <arpa/inet.h> | ||
| 4 | |||
| 5 | #include <chrono> | ||
| 6 | #include <cstdlib> | ||
| 7 | #include <cstring> | ||
| 8 | #include <limits> | ||
| 9 | #include <stdexcept> | ||
| 10 | #include <string> | ||
| 11 | #include <thread> | ||
| 12 | |||
| 13 | #include "ps/rdma/control_plane.h" | ||
| 14 | |||
| 15 | namespace petps { | ||
| 16 | namespace { | ||
| 17 | |||
| 18 | constexpr int kRawVerbsPort = 1; | ||
| 19 | constexpr int kRawVerbsGidIndex = 1; | ||
| 20 | constexpr std::uint32_t kRawVerbsPsn = 3185; | ||
| 21 | constexpr int kRawVerbsCqDepth = 4096; | ||
| 22 | constexpr int kRawVerbsRecvDepth = 1024; | ||
| 23 | |||
| 24 | ✗ | std::string IbvError(const char* op) { return std::string(op) + " failed"; } | |
| 25 | |||
| 26 | ✗ | std::string QpCreateError(const RawVerbsConfig& config, int node) { | |
| 27 | return "ibv_create_qp failed: likely insufficient RDMA QP resources " | ||
| 28 | ✗ | "(global_id=" + | |
| 29 | ✗ | std::to_string(config.global_id) + | |
| 30 | ✗ | ", local_lane=" + std::to_string(config.local_lane) + | |
| 31 | ✗ | ", remote_lane=" + std::to_string(config.remote_lane) + | |
| 32 | ✗ | ", node=" + std::to_string(node) + | |
| 33 | ✗ | ", num_servers=" + std::to_string(config.num_servers) + | |
| 34 | ✗ | ", num_clients=" + std::to_string(config.num_clients) + | |
| 35 | ✗ | "). Reduce --client-count or --qps-per-client-per-shard."; | |
| 36 | } | ||
| 37 | |||
| 38 | ✗ | ibv_context* OpenDeviceForNuma(int numa_id) { | |
| 39 | ✗ | int device_count = 0; | |
| 40 | ✗ | ibv_device** devices = ibv_get_device_list(&device_count); | |
| 41 | ✗ | if (devices == nullptr || device_count == 0) { | |
| 42 | ✗ | throw std::runtime_error("no RDMA devices found"); | |
| 43 | } | ||
| 44 | ✗ | const int device_index = SelectRawVerbsDeviceIndex(numa_id, device_count); | |
| 45 | ✗ | ibv_context* context = ibv_open_device(devices[device_index]); | |
| 46 | ✗ | ibv_free_device_list(devices); | |
| 47 | ✗ | if (context == nullptr) { | |
| 48 | ✗ | throw std::runtime_error("ibv_open_device failed"); | |
| 49 | } | ||
| 50 | ✗ | return context; | |
| 51 | } | ||
| 52 | |||
| 53 | ✗ | void ModifyQpToInit(ibv_qp* qp) { | |
| 54 | ✗ | ibv_qp_attr attr{}; | |
| 55 | ✗ | attr.qp_state = IBV_QPS_INIT; | |
| 56 | ✗ | attr.port_num = kRawVerbsPort; | |
| 57 | ✗ | attr.pkey_index = 0; | |
| 58 | ✗ | attr.qp_access_flags = IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE | | |
| 59 | IBV_ACCESS_REMOTE_ATOMIC; | ||
| 60 | ✗ | const int flags = | |
| 61 | IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS; | ||
| 62 | ✗ | if (ibv_modify_qp(qp, &attr, flags) != 0) { | |
| 63 | ✗ | throw std::runtime_error(IbvError("ibv_modify_qp INIT")); | |
| 64 | } | ||
| 65 | ✗ | } | |
| 66 | |||
| 67 | ✗ | void FillAhAttr(ibv_ah_attr* ah_attr, | |
| 68 | std::uint16_t remote_lid, | ||
| 69 | const std::uint8_t* remote_gid) { | ||
| 70 | ✗ | std::memset(ah_attr, 0, sizeof(*ah_attr)); | |
| 71 | ✗ | ah_attr->dlid = remote_lid; | |
| 72 | ✗ | ah_attr->sl = 0; | |
| 73 | ✗ | ah_attr->src_path_bits = 0; | |
| 74 | ✗ | ah_attr->port_num = kRawVerbsPort; | |
| 75 | ✗ | if (remote_gid != nullptr) { | |
| 76 | ✗ | ah_attr->is_global = 1; | |
| 77 | ✗ | std::memcpy(&ah_attr->grh.dgid, remote_gid, 16); | |
| 78 | ✗ | ah_attr->grh.sgid_index = kRawVerbsGidIndex; | |
| 79 | ✗ | ah_attr->grh.hop_limit = 1; | |
| 80 | } | ||
| 81 | ✗ | } | |
| 82 | |||
| 83 | ✗ | void ModifyQpToRtr(ibv_qp* qp, const RawVerbsNodeMeta& remote) { | |
| 84 | ✗ | ibv_qp_attr attr{}; | |
| 85 | ✗ | attr.qp_state = IBV_QPS_RTR; | |
| 86 | ✗ | attr.path_mtu = IBV_MTU_4096; | |
| 87 | ✗ | attr.dest_qp_num = remote.qpn; | |
| 88 | ✗ | attr.rq_psn = remote.psn; | |
| 89 | ✗ | attr.max_dest_rd_atomic = 16; | |
| 90 | ✗ | attr.min_rnr_timer = 12; | |
| 91 | ✗ | FillAhAttr(&attr.ah_attr, remote.lid, remote.gid); | |
| 92 | ✗ | const int flags = | |
| 93 | IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | | ||
| 94 | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER; | ||
| 95 | ✗ | if (ibv_modify_qp(qp, &attr, flags) != 0) { | |
| 96 | ✗ | throw std::runtime_error(IbvError("ibv_modify_qp RTR")); | |
| 97 | } | ||
| 98 | ✗ | } | |
| 99 | |||
| 100 | ✗ | void ModifyQpToRts(ibv_qp* qp) { | |
| 101 | ✗ | ibv_qp_attr attr{}; | |
| 102 | ✗ | attr.qp_state = IBV_QPS_RTS; | |
| 103 | ✗ | attr.sq_psn = kRawVerbsPsn; | |
| 104 | ✗ | attr.timeout = 14; | |
| 105 | ✗ | attr.retry_cnt = 7; | |
| 106 | ✗ | attr.rnr_retry = 7; | |
| 107 | ✗ | attr.max_rd_atomic = 16; | |
| 108 | ✗ | const int flags = | |
| 109 | IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | | ||
| 110 | IBV_QP_RNR_RETRY | IBV_QP_MAX_QP_RD_ATOMIC; | ||
| 111 | ✗ | if (ibv_modify_qp(qp, &attr, flags) != 0) { | |
| 112 | ✗ | throw std::runtime_error(IbvError("ibv_modify_qp RTS")); | |
| 113 | } | ||
| 114 | ✗ | } | |
| 115 | |||
| 116 | } // namespace | ||
| 117 | |||
| 118 | struct RawVerbsTransport::Impl { | ||
| 119 | ✗ | explicit Impl(const RawVerbsConfig& c) | |
| 120 | ✗ | : config(c), allocator(c.local_region_bytes, c.allocation_start_offset) {} | |
| 121 | |||
| 122 | RawVerbsConfig config; | ||
| 123 | ibv_context* context = nullptr; | ||
| 124 | ibv_pd* pd = nullptr; | ||
| 125 | ibv_cq* cq = nullptr; | ||
| 126 | ibv_mr* local_mr = nullptr; | ||
| 127 | std::vector<ibv_mr*> extra_mrs; | ||
| 128 | void* local_base = nullptr; | ||
| 129 | bool owns_local_base = false; | ||
| 130 | std::size_t local_bytes = 0; | ||
| 131 | RawVerbsRegionAllocator allocator; | ||
| 132 | std::vector<RawVerbsNodeMeta> metas; | ||
| 133 | std::vector<RawVerbsRemoteMemory> remotes; | ||
| 134 | std::vector<ibv_qp*> qps; | ||
| 135 | std::vector<std::uint32_t> max_inline_data_per_node; | ||
| 136 | ibv_wc wc_batch[kRawVerbsPollBatchSize] = {}; | ||
| 137 | RawVerbsCompletionBatchCursor batch_cursor; | ||
| 138 | }; | ||
| 139 | |||
| 140 | ✗ | RawVerbsTransport::RawVerbsTransport(const RawVerbsConfig& config) | |
| 141 | ✗ | : impl_(std::make_unique<Impl>(config)) { | |
| 142 | ✗ | impl_->context = OpenDeviceForNuma(config.numa_id); | |
| 143 | ✗ | impl_->pd = ibv_alloc_pd(impl_->context); | |
| 144 | ✗ | if (impl_->pd == nullptr) { | |
| 145 | ✗ | throw std::runtime_error("ibv_alloc_pd failed"); | |
| 146 | } | ||
| 147 | ✗ | impl_->cq = | |
| 148 | ✗ | ibv_create_cq(impl_->context, kRawVerbsCqDepth, nullptr, nullptr, 0); | |
| 149 | ✗ | if (impl_->cq == nullptr) { | |
| 150 | ✗ | throw std::runtime_error("ibv_create_cq failed"); | |
| 151 | } | ||
| 152 | |||
| 153 | ✗ | impl_->local_bytes = config.local_region_bytes; | |
| 154 | ✗ | impl_->local_base = reinterpret_cast<void*>(config.local_base_addr); | |
| 155 | ✗ | if (impl_->local_base == nullptr) { | |
| 156 | ✗ | const int rc = posix_memalign(&impl_->local_base, 4096, impl_->local_bytes); | |
| 157 | ✗ | if (rc != 0) { | |
| 158 | ✗ | throw std::runtime_error("posix_memalign failed for raw verbs region"); | |
| 159 | } | ||
| 160 | ✗ | impl_->owns_local_base = true; | |
| 161 | } | ||
| 162 | ✗ | impl_->local_mr = ibv_reg_mr( | |
| 163 | impl_->pd, | ||
| 164 | impl_->local_base, | ||
| 165 | impl_->local_bytes, | ||
| 166 | IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | | ||
| 167 | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC); | ||
| 168 | ✗ | if (impl_->local_mr == nullptr) { | |
| 169 | ✗ | throw std::runtime_error("ibv_reg_mr failed"); | |
| 170 | } | ||
| 171 | ✗ | if (config.reserved_region_bytes != 0) { | |
| 172 | ✗ | impl_->allocator.SetReservedRegion( | |
| 173 | ✗ | {config.reserved_region_offset, config.reserved_region_bytes}); | |
| 174 | } | ||
| 175 | |||
| 176 | ✗ | const int node_count = config.num_servers + config.num_clients; | |
| 177 | ✗ | impl_->qps.resize(static_cast<std::size_t>(node_count), nullptr); | |
| 178 | ✗ | impl_->max_inline_data_per_node.assign( | |
| 179 | ✗ | static_cast<std::size_t>(node_count), 0); | |
| 180 | ✗ | for (int node = 0; node < node_count; ++node) { | |
| 181 | ✗ | if (!ShouldRawVerbsConnectToNode(config, node)) { | |
| 182 | ✗ | continue; | |
| 183 | } | ||
| 184 | ✗ | ibv_qp_init_attr init_attr{}; | |
| 185 | ✗ | init_attr.send_cq = impl_->cq; | |
| 186 | ✗ | init_attr.recv_cq = impl_->cq; | |
| 187 | ✗ | init_attr.qp_type = IBV_QPT_RC; | |
| 188 | ✗ | init_attr.cap.max_send_wr = 1024; | |
| 189 | ✗ | init_attr.cap.max_recv_wr = kRawVerbsRecvDepth; | |
| 190 | ✗ | init_attr.cap.max_send_sge = 32; | |
| 191 | ✗ | init_attr.cap.max_recv_sge = 1; | |
| 192 | ✗ | init_attr.cap.max_inline_data = config.max_inline_data; | |
| 193 | ✗ | ibv_qp* qp = ibv_create_qp(impl_->pd, &init_attr); | |
| 194 | ✗ | if (qp == nullptr) { | |
| 195 | ✗ | throw std::runtime_error(QpCreateError(config, node)); | |
| 196 | } | ||
| 197 | ✗ | impl_->qps[static_cast<std::size_t>(node)] = qp; | |
| 198 | ✗ | impl_->max_inline_data_per_node[static_cast<std::size_t>(node)] = | |
| 199 | ✗ | init_attr.cap.max_inline_data; | |
| 200 | } | ||
| 201 | ✗ | } | |
| 202 | |||
| 203 | ✗ | RawVerbsTransport::~RawVerbsTransport() { | |
| 204 | ✗ | if (!impl_) { | |
| 205 | ✗ | return; | |
| 206 | } | ||
| 207 | ✗ | for (ibv_qp* qp : impl_->qps) { | |
| 208 | ✗ | if (qp != nullptr) { | |
| 209 | ✗ | ibv_destroy_qp(qp); | |
| 210 | } | ||
| 211 | } | ||
| 212 | ✗ | if (impl_->local_mr != nullptr) { | |
| 213 | ✗ | ibv_dereg_mr(impl_->local_mr); | |
| 214 | } | ||
| 215 | ✗ | for (ibv_mr* mr : impl_->extra_mrs) { | |
| 216 | ✗ | if (mr != nullptr) { | |
| 217 | ✗ | ibv_dereg_mr(mr); | |
| 218 | } | ||
| 219 | } | ||
| 220 | ✗ | if (impl_->cq != nullptr) { | |
| 221 | ✗ | ibv_destroy_cq(impl_->cq); | |
| 222 | } | ||
| 223 | ✗ | if (impl_->pd != nullptr) { | |
| 224 | ✗ | ibv_dealloc_pd(impl_->pd); | |
| 225 | } | ||
| 226 | ✗ | if (impl_->context != nullptr) { | |
| 227 | ✗ | ibv_close_device(impl_->context); | |
| 228 | } | ||
| 229 | ✗ | if (impl_->owns_local_base && impl_->local_base != nullptr) { | |
| 230 | ✗ | free(impl_->local_base); | |
| 231 | } | ||
| 232 | ✗ | } | |
| 233 | |||
| 234 | ✗ | void RawVerbsTransport::RegisterThread() {} | |
| 235 | |||
| 236 | namespace { | ||
| 237 | ✗ | bool MrContains(ibv_mr* mr, const void* ptr, std::size_t bytes) { | |
| 238 | ✗ | if (mr == nullptr) { | |
| 239 | ✗ | return false; | |
| 240 | } | ||
| 241 | ✗ | const auto begin = reinterpret_cast<std::uintptr_t>(ptr); | |
| 242 | ✗ | const auto end = begin + bytes; | |
| 243 | ✗ | const auto mr_begin = reinterpret_cast<std::uintptr_t>(mr->addr); | |
| 244 | ✗ | const auto mr_end = mr_begin + static_cast<std::uintptr_t>(mr->length); | |
| 245 | ✗ | return begin >= mr_begin && end >= begin && end <= mr_end; | |
| 246 | } | ||
| 247 | } // namespace | ||
| 248 | |||
| 249 | ibv_mr* | ||
| 250 | ✗ | RawVerbsTransport::FindLocalMr(const void* ptr, std::size_t bytes) const { | |
| 251 | ✗ | if (impl_->local_mr != nullptr && MrContains(impl_->local_mr, ptr, bytes)) { | |
| 252 | ✗ | return impl_->local_mr; | |
| 253 | } | ||
| 254 | ✗ | for (ibv_mr* mr : impl_->extra_mrs) { | |
| 255 | ✗ | if (MrContains(mr, ptr, bytes)) { | |
| 256 | ✗ | return mr; | |
| 257 | } | ||
| 258 | } | ||
| 259 | ✗ | return nullptr; | |
| 260 | } | ||
| 261 | |||
| 262 | ✗ | void RawVerbsTransport::RegisterMemoryRegion(void* base, std::size_t bytes) { | |
| 263 | ✗ | if (base == nullptr || bytes == 0) { | |
| 264 | ✗ | return; | |
| 265 | } | ||
| 266 | ✗ | if (FindLocalMr(base, bytes) != nullptr) { | |
| 267 | ✗ | return; | |
| 268 | } | ||
| 269 | ✗ | ibv_mr* mr = ibv_reg_mr( | |
| 270 | impl_->pd, | ||
| 271 | base, | ||
| 272 | bytes, | ||
| 273 | IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | | ||
| 274 | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC); | ||
| 275 | ✗ | if (mr == nullptr) { | |
| 276 | ✗ | throw std::runtime_error("ibv_reg_mr failed for extra raw verbs region"); | |
| 277 | } | ||
| 278 | ✗ | impl_->extra_mrs.push_back(mr); | |
| 279 | } | ||
| 280 | |||
| 281 | ✗ | void* RawVerbsTransport::AllocateRegistered(std::size_t bytes) { | |
| 282 | ✗ | const std::uint64_t offset = impl_->allocator.Allocate(bytes); | |
| 283 | ✗ | return static_cast<char*>(impl_->local_base) + offset; | |
| 284 | } | ||
| 285 | |||
| 286 | ✗ | std::uint64_t RawVerbsTransport::SaveAllocationState() const { | |
| 287 | ✗ | return impl_->allocator.Checkpoint(); | |
| 288 | } | ||
| 289 | |||
| 290 | ✗ | void RawVerbsTransport::RestoreAllocationState(std::uint64_t checkpoint) { | |
| 291 | ✗ | impl_->allocator.Restore(checkpoint); | |
| 292 | ✗ | } | |
| 293 | |||
| 294 | ✗ | GlobalAddress RawVerbsTransport::LocalAddress(void* ptr) const { | |
| 295 | ✗ | const auto base = reinterpret_cast<std::uintptr_t>(impl_->local_base); | |
| 296 | ✗ | const auto addr = reinterpret_cast<std::uintptr_t>(ptr); | |
| 297 | ✗ | if (addr < base || addr >= base + impl_->local_bytes) { | |
| 298 | ✗ | throw std::runtime_error("pointer outside raw verbs local region"); | |
| 299 | } | ||
| 300 | return GlobalAddress{ | ||
| 301 | ✗ | static_cast<std::uint16_t>(impl_->config.global_id), | |
| 302 | ✗ | static_cast<std::uint64_t>(addr - base), | |
| 303 | ✗ | }; | |
| 304 | } | ||
| 305 | |||
| 306 | ✗ | void* RawVerbsTransport::LocalPointer(GlobalAddress address) const { | |
| 307 | ✗ | if (address.nodeID != impl_->config.global_id) { | |
| 308 | ✗ | throw std::runtime_error( | |
| 309 | ✗ | "raw verbs local pointer requested for remote node"); | |
| 310 | } | ||
| 311 | ✗ | if (address.offset >= impl_->local_bytes) { | |
| 312 | ✗ | throw std::runtime_error("raw verbs local offset outside region"); | |
| 313 | } | ||
| 314 | ✗ | return static_cast<char*>(impl_->local_base) + address.offset; | |
| 315 | } | ||
| 316 | |||
| 317 | ✗ | RawVerbsNodeMeta RawVerbsTransport::LocalMeta() const { | |
| 318 | ✗ | ibv_port_attr port_attr{}; | |
| 319 | ✗ | if (ibv_query_port(impl_->context, kRawVerbsPort, &port_attr) != 0) { | |
| 320 | ✗ | throw std::runtime_error("ibv_query_port failed"); | |
| 321 | } | ||
| 322 | ✗ | ibv_gid gid{}; | |
| 323 | ✗ | if (ibv_query_gid(impl_->context, kRawVerbsPort, kRawVerbsGidIndex, &gid) != | |
| 324 | 0) { | ||
| 325 | ✗ | throw std::runtime_error("ibv_query_gid failed"); | |
| 326 | } | ||
| 327 | ✗ | RawVerbsNodeMeta meta{}; | |
| 328 | ✗ | meta.node_id = static_cast<std::uint16_t>(impl_->config.global_id); | |
| 329 | ✗ | meta.lid = port_attr.lid; | |
| 330 | ✗ | meta.psn = kRawVerbsPsn; | |
| 331 | ✗ | meta.rkey = impl_->local_mr->rkey; | |
| 332 | ✗ | meta.base_addr = reinterpret_cast<std::uint64_t>(impl_->local_base); | |
| 333 | ✗ | std::memcpy(meta.gid, &gid, sizeof(meta.gid)); | |
| 334 | ✗ | return meta; | |
| 335 | } | ||
| 336 | |||
| 337 | ✗ | void RawVerbsTransport::PublishAndConnect() { | |
| 338 | ✗ | const int node_count = impl_->config.num_servers + impl_->config.num_clients; | |
| 339 | ✗ | const RawVerbsNodeMeta local = LocalMeta(); | |
| 340 | RdmaControlPlaneClient control_plane({ | ||
| 341 | ✗ | impl_->config.control_plane_host, | |
| 342 | ✗ | impl_->config.control_plane_port, | |
| 343 | ✗ | impl_->config.control_plane_timeout_ms, | |
| 344 | ✗ | }); | |
| 345 | ✗ | for (int node = 0; node < node_count; ++node) { | |
| 346 | ✗ | if (!ShouldRawVerbsConnectToNode(impl_->config, node)) { | |
| 347 | ✗ | continue; | |
| 348 | } | ||
| 349 | ✗ | RawVerbsNodeMeta peer_local = local; | |
| 350 | ✗ | peer_local.qpn = impl_->qps[static_cast<std::size_t>(node)]->qp_num; | |
| 351 | ✗ | control_plane.PublishMeta( | |
| 352 | ✗ | impl_->config.global_id, | |
| 353 | ✗ | impl_->config.local_lane, | |
| 354 | node, | ||
| 355 | ✗ | impl_->config.remote_lane, | |
| 356 | peer_local); | ||
| 357 | } | ||
| 358 | |||
| 359 | ✗ | impl_->metas.assign(static_cast<std::size_t>(node_count), RawVerbsNodeMeta{}); | |
| 360 | ✗ | impl_->remotes.assign( | |
| 361 | ✗ | static_cast<std::size_t>(node_count), RawVerbsRemoteMemory{}); | |
| 362 | |||
| 363 | ✗ | for (int node = 0; node < node_count; ++node) { | |
| 364 | ✗ | if (node == impl_->config.global_id) { | |
| 365 | ✗ | impl_->metas[static_cast<std::size_t>(node)] = local; | |
| 366 | ✗ | impl_->remotes[static_cast<std::size_t>(node)] = RawVerbsRemoteMemory{ | |
| 367 | ✗ | local.node_id, | |
| 368 | ✗ | local.base_addr, | |
| 369 | ✗ | local.rkey, | |
| 370 | }; | ||
| 371 | ✗ | continue; | |
| 372 | } | ||
| 373 | ✗ | if (!ShouldRawVerbsConnectToNode(impl_->config, node)) { | |
| 374 | ✗ | continue; | |
| 375 | } | ||
| 376 | ✗ | const RawVerbsNodeMeta meta = control_plane.GetMeta( | |
| 377 | node, | ||
| 378 | ✗ | impl_->config.remote_lane, | |
| 379 | ✗ | impl_->config.global_id, | |
| 380 | ✗ | impl_->config.local_lane, | |
| 381 | ✗ | impl_->config.control_plane_timeout_ms); | |
| 382 | ✗ | impl_->metas[static_cast<std::size_t>(node)] = meta; | |
| 383 | ✗ | impl_->remotes[static_cast<std::size_t>(node)] = RawVerbsRemoteMemory{ | |
| 384 | ✗ | meta.node_id, | |
| 385 | ✗ | meta.base_addr, | |
| 386 | ✗ | meta.rkey, | |
| 387 | }; | ||
| 388 | } | ||
| 389 | |||
| 390 | ✗ | for (int node = 0; node < node_count; ++node) { | |
| 391 | ✗ | if (!ShouldRawVerbsConnectToNode(impl_->config, node)) { | |
| 392 | ✗ | continue; | |
| 393 | } | ||
| 394 | ✗ | ibv_qp* qp = impl_->qps[static_cast<std::size_t>(node)]; | |
| 395 | ✗ | ModifyQpToInit(qp); | |
| 396 | ✗ | ModifyQpToRtr(qp, impl_->metas[static_cast<std::size_t>(node)]); | |
| 397 | ✗ | ModifyQpToRts(qp); | |
| 398 | ✗ | for (int i = 0; i < kRawVerbsRecvDepth; ++i) { | |
| 399 | ✗ | ibv_recv_wr wr{}; | |
| 400 | ✗ | ibv_recv_wr* bad_wr = nullptr; | |
| 401 | ✗ | wr.wr_id = static_cast<std::uint64_t>(node); | |
| 402 | ✗ | wr.sg_list = nullptr; | |
| 403 | ✗ | wr.num_sge = 0; | |
| 404 | ✗ | if (ibv_post_recv(qp, &wr, &bad_wr) != 0) { | |
| 405 | ✗ | throw std::runtime_error("ibv_post_recv failed"); | |
| 406 | } | ||
| 407 | } | ||
| 408 | } | ||
| 409 | ✗ | } | |
| 410 | |||
| 411 | ✗ | void RawVerbsTransport::Write( | |
| 412 | const void* local, | ||
| 413 | GlobalAddress remote, | ||
| 414 | std::size_t bytes, | ||
| 415 | std::uint64_t wr_id, | ||
| 416 | bool signaled) { | ||
| 417 | ✗ | if (remote.nodeID >= impl_->remotes.size()) { | |
| 418 | ✗ | throw std::runtime_error("raw verbs write remote node out of range"); | |
| 419 | } | ||
| 420 | ✗ | ibv_sge sge{}; | |
| 421 | ✗ | sge.addr = reinterpret_cast<std::uint64_t>(local); | |
| 422 | ✗ | sge.length = static_cast<std::uint32_t>(bytes); | |
| 423 | ✗ | ibv_mr* local_mr = FindLocalMr(local, bytes); | |
| 424 | ✗ | if (local_mr == nullptr) { | |
| 425 | ✗ | throw std::runtime_error("raw verbs write local buffer is not registered"); | |
| 426 | } | ||
| 427 | ✗ | sge.lkey = local_mr->lkey; | |
| 428 | ✗ | ibv_send_wr wr{}; | |
| 429 | ✗ | wr.wr_id = wr_id; | |
| 430 | ✗ | wr.opcode = IBV_WR_RDMA_WRITE; | |
| 431 | ✗ | wr.send_flags = signaled ? IBV_SEND_SIGNALED : 0; | |
| 432 | ✗ | if (bytes > 0 && | |
| 433 | ✗ | bytes <= impl_->max_inline_data_per_node[static_cast<std::size_t>( | |
| 434 | ✗ | remote.nodeID)]) { | |
| 435 | ✗ | wr.send_flags |= IBV_SEND_INLINE; | |
| 436 | } | ||
| 437 | ✗ | wr.sg_list = &sge; | |
| 438 | ✗ | wr.num_sge = 1; | |
| 439 | ✗ | wr.wr.rdma.remote_addr = | |
| 440 | ✗ | impl_->remotes[remote.nodeID].base_addr + remote.offset; | |
| 441 | ✗ | wr.wr.rdma.rkey = impl_->remotes[remote.nodeID].rkey; | |
| 442 | ✗ | ibv_send_wr* bad_wr = nullptr; | |
| 443 | ✗ | if (ibv_post_send(impl_->qps[remote.nodeID], &wr, &bad_wr) != 0) { | |
| 444 | ✗ | throw std::runtime_error("ibv_post_send write failed"); | |
| 445 | } | ||
| 446 | ✗ | } | |
| 447 | |||
| 448 | ✗ | void RawVerbsTransport::WriteSg( | |
| 449 | base::ConstArray<RawVerbsSge> sges, | ||
| 450 | GlobalAddress remote, | ||
| 451 | std::uint64_t wr_id, | ||
| 452 | bool signaled) { | ||
| 453 | ✗ | if (remote.nodeID >= impl_->remotes.size()) { | |
| 454 | ✗ | throw std::runtime_error("raw verbs write-sg remote node out of range"); | |
| 455 | } | ||
| 456 | ✗ | if (sges.Size() == 0) { | |
| 457 | ✗ | return; | |
| 458 | } | ||
| 459 | ✗ | std::vector<ibv_sge> verbs_sges; | |
| 460 | ✗ | verbs_sges.reserve(static_cast<std::size_t>(sges.Size())); | |
| 461 | ✗ | for (const auto& entry : sges) { | |
| 462 | ✗ | if (entry.bytes == 0) { | |
| 463 | ✗ | continue; | |
| 464 | } | ||
| 465 | ✗ | if (entry.bytes > std::numeric_limits<std::uint32_t>::max()) { | |
| 466 | ✗ | throw std::runtime_error("raw verbs write-sg entry too large"); | |
| 467 | } | ||
| 468 | ✗ | ibv_mr* local_mr = FindLocalMr(entry.data, entry.bytes); | |
| 469 | ✗ | if (local_mr == nullptr) { | |
| 470 | ✗ | throw std::runtime_error( | |
| 471 | ✗ | "raw verbs write-sg local buffer is not registered"); | |
| 472 | } | ||
| 473 | ✗ | ibv_sge sge{}; | |
| 474 | ✗ | sge.addr = reinterpret_cast<std::uint64_t>(entry.data); | |
| 475 | ✗ | sge.length = static_cast<std::uint32_t>(entry.bytes); | |
| 476 | ✗ | sge.lkey = local_mr->lkey; | |
| 477 | ✗ | verbs_sges.push_back(sge); | |
| 478 | } | ||
| 479 | ✗ | if (verbs_sges.empty()) { | |
| 480 | ✗ | return; | |
| 481 | } | ||
| 482 | ✗ | ibv_send_wr wr{}; | |
| 483 | ✗ | wr.wr_id = wr_id; | |
| 484 | ✗ | wr.opcode = IBV_WR_RDMA_WRITE; | |
| 485 | ✗ | wr.send_flags = signaled ? IBV_SEND_SIGNALED : 0; | |
| 486 | ✗ | wr.sg_list = verbs_sges.data(); | |
| 487 | ✗ | wr.num_sge = static_cast<int>(verbs_sges.size()); | |
| 488 | ✗ | wr.wr.rdma.remote_addr = | |
| 489 | ✗ | impl_->remotes[remote.nodeID].base_addr + remote.offset; | |
| 490 | ✗ | wr.wr.rdma.rkey = impl_->remotes[remote.nodeID].rkey; | |
| 491 | ✗ | ibv_send_wr* bad_wr = nullptr; | |
| 492 | ✗ | if (ibv_post_send(impl_->qps[remote.nodeID], &wr, &bad_wr) != 0) { | |
| 493 | ✗ | throw std::runtime_error("ibv_post_send write-sg failed"); | |
| 494 | } | ||
| 495 | ✗ | } | |
| 496 | |||
| 497 | ✗ | void RawVerbsTransport::WriteWithImm( | |
| 498 | const void* local, | ||
| 499 | GlobalAddress remote, | ||
| 500 | std::size_t bytes, | ||
| 501 | std::uint32_t imm_data, | ||
| 502 | std::uint64_t wr_id, | ||
| 503 | bool signaled) { | ||
| 504 | ✗ | if (remote.nodeID >= impl_->remotes.size()) { | |
| 505 | ✗ | throw std::runtime_error( | |
| 506 | ✗ | "raw verbs write-with-imm remote node out of range"); | |
| 507 | } | ||
| 508 | ✗ | ibv_sge sge{}; | |
| 509 | ✗ | sge.addr = reinterpret_cast<std::uint64_t>(local); | |
| 510 | ✗ | sge.length = static_cast<std::uint32_t>(bytes); | |
| 511 | ✗ | sge.lkey = impl_->local_mr->lkey; | |
| 512 | ✗ | ibv_send_wr wr{}; | |
| 513 | ✗ | wr.wr_id = wr_id; | |
| 514 | ✗ | wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; | |
| 515 | ✗ | wr.imm_data = htonl(imm_data); | |
| 516 | ✗ | wr.send_flags = signaled ? IBV_SEND_SIGNALED : 0; | |
| 517 | ✗ | if (bytes > 0 && | |
| 518 | ✗ | bytes <= impl_->max_inline_data_per_node[static_cast<std::size_t>( | |
| 519 | ✗ | remote.nodeID)]) { | |
| 520 | ✗ | wr.send_flags |= IBV_SEND_INLINE; | |
| 521 | } | ||
| 522 | ✗ | wr.sg_list = &sge; | |
| 523 | ✗ | wr.num_sge = 1; | |
| 524 | ✗ | wr.wr.rdma.remote_addr = | |
| 525 | ✗ | impl_->remotes[remote.nodeID].base_addr + remote.offset; | |
| 526 | ✗ | wr.wr.rdma.rkey = impl_->remotes[remote.nodeID].rkey; | |
| 527 | ✗ | ibv_send_wr* bad_wr = nullptr; | |
| 528 | ✗ | if (ibv_post_send(impl_->qps[remote.nodeID], &wr, &bad_wr) != 0) { | |
| 529 | ✗ | throw std::runtime_error("ibv_post_send write-with-imm failed"); | |
| 530 | } | ||
| 531 | ✗ | } | |
| 532 | |||
| 533 | ✗ | void RawVerbsTransport::Read( | |
| 534 | void* local, | ||
| 535 | GlobalAddress remote, | ||
| 536 | std::size_t bytes, | ||
| 537 | std::uint64_t wr_id, | ||
| 538 | bool signaled) { | ||
| 539 | ✗ | if (remote.nodeID >= impl_->remotes.size()) { | |
| 540 | ✗ | throw std::runtime_error("raw verbs read remote node out of range"); | |
| 541 | } | ||
| 542 | ✗ | ibv_sge sge{}; | |
| 543 | ✗ | sge.addr = reinterpret_cast<std::uint64_t>(local); | |
| 544 | ✗ | sge.length = static_cast<std::uint32_t>(bytes); | |
| 545 | ✗ | sge.lkey = impl_->local_mr->lkey; | |
| 546 | ✗ | ibv_send_wr wr{}; | |
| 547 | ✗ | wr.wr_id = wr_id; | |
| 548 | ✗ | wr.opcode = IBV_WR_RDMA_READ; | |
| 549 | ✗ | wr.send_flags = signaled ? IBV_SEND_SIGNALED : 0; | |
| 550 | ✗ | wr.sg_list = &sge; | |
| 551 | ✗ | wr.num_sge = 1; | |
| 552 | ✗ | wr.wr.rdma.remote_addr = | |
| 553 | ✗ | impl_->remotes[remote.nodeID].base_addr + remote.offset; | |
| 554 | ✗ | wr.wr.rdma.rkey = impl_->remotes[remote.nodeID].rkey; | |
| 555 | ✗ | ibv_send_wr* bad_wr = nullptr; | |
| 556 | ✗ | if (ibv_post_send(impl_->qps[remote.nodeID], &wr, &bad_wr) != 0) { | |
| 557 | ✗ | throw std::runtime_error("ibv_post_send read failed"); | |
| 558 | } | ||
| 559 | ✗ | } | |
| 560 | |||
| 561 | ✗ | void RawVerbsTransport::SendDoorbell( | |
| 562 | std::uint16_t node_id, std::uint32_t imm_data, std::uint64_t wr_id) { | ||
| 563 | ✗ | if (node_id >= impl_->qps.size()) { | |
| 564 | ✗ | throw std::runtime_error("raw verbs doorbell remote node out of range"); | |
| 565 | } | ||
| 566 | ✗ | ibv_send_wr wr{}; | |
| 567 | ✗ | wr.opcode = IBV_WR_SEND_WITH_IMM; | |
| 568 | ✗ | wr.imm_data = htonl(imm_data); | |
| 569 | ✗ | wr.send_flags = IBV_SEND_SIGNALED; | |
| 570 | ✗ | wr.wr_id = wr_id; | |
| 571 | ✗ | wr.sg_list = nullptr; | |
| 572 | ✗ | wr.num_sge = 0; | |
| 573 | ✗ | ibv_send_wr* bad_wr = nullptr; | |
| 574 | ✗ | if (ibv_post_send(impl_->qps[node_id], &wr, &bad_wr) != 0) { | |
| 575 | ✗ | throw std::runtime_error("ibv_post_send doorbell failed"); | |
| 576 | } | ||
| 577 | ✗ | } | |
| 578 | |||
| 579 | ✗ | bool RawVerbsTransport::Poll(RawVerbsCompletion* completion, int timeout_ms) { | |
| 580 | const auto deadline = | ||
| 581 | ✗ | timeout_ms > 0 | |
| 582 | ✗ | ? std::chrono::steady_clock::now() + | |
| 583 | ✗ | std::chrono::milliseconds(timeout_ms) | |
| 584 | ✗ | : std::chrono::steady_clock::time_point::max(); | |
| 585 | while (true) { | ||
| 586 | ✗ | if (!impl_->batch_cursor.HasCachedCompletion()) { | |
| 587 | const int n = | ||
| 588 | ✗ | ibv_poll_cq(impl_->cq, kRawVerbsPollBatchSize, impl_->wc_batch); | |
| 589 | ✗ | if (n < 0) { | |
| 590 | ✗ | throw std::runtime_error("ibv_poll_cq failed"); | |
| 591 | } | ||
| 592 | ✗ | if (n == 0) { | |
| 593 | ✗ | if (std::chrono::steady_clock::now() >= deadline) { | |
| 594 | ✗ | return false; | |
| 595 | } | ||
| 596 | ✗ | std::this_thread::yield(); | |
| 597 | ✗ | continue; | |
| 598 | } | ||
| 599 | ✗ | impl_->batch_cursor.Reset(impl_->wc_batch, n); | |
| 600 | } | ||
| 601 | ✗ | ibv_wc& wc = *impl_->batch_cursor.TakeCachedCompletion(); | |
| 602 | ✗ | if (wc.status != IBV_WC_SUCCESS) { | |
| 603 | ✗ | throw std::runtime_error( | |
| 604 | ✗ | std::string("raw verbs CQ error: ") + ibv_wc_status_str(wc.status)); | |
| 605 | } | ||
| 606 | ✗ | if (completion != nullptr) { | |
| 607 | ✗ | completion->wr_id = wc.wr_id; | |
| 608 | ✗ | completion->opcode = wc.opcode; | |
| 609 | ✗ | completion->has_imm = (wc.wc_flags & IBV_WC_WITH_IMM) != 0; | |
| 610 | ✗ | completion->imm_data = completion->has_imm ? ntohl(wc.imm_data) : 0; | |
| 611 | } | ||
| 612 | ✗ | if (wc.opcode == IBV_WC_RECV || wc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { | |
| 613 | ✗ | const std::uint16_t node_id = static_cast<std::uint16_t>(wc.wr_id); | |
| 614 | ✗ | ibv_recv_wr wr{}; | |
| 615 | ✗ | ibv_recv_wr* bad_wr = nullptr; | |
| 616 | ✗ | wr.wr_id = node_id; | |
| 617 | ✗ | wr.sg_list = nullptr; | |
| 618 | ✗ | wr.num_sge = 0; | |
| 619 | ✗ | if (node_id < impl_->qps.size() && | |
| 620 | ✗ | ibv_post_recv(impl_->qps[node_id], &wr, &bad_wr) != 0) { | |
| 621 | ✗ | throw std::runtime_error("ibv_post_recv repost failed"); | |
| 622 | } | ||
| 623 | } | ||
| 624 | ✗ | return true; | |
| 625 | ✗ | } | |
| 626 | } | ||
| 627 | |||
| 628 | ✗ | std::uint32_t RawVerbsTransport::max_inline_data(std::uint16_t node_id) const { | |
| 629 | ✗ | if (node_id >= impl_->max_inline_data_per_node.size()) { | |
| 630 | ✗ | throw std::runtime_error("raw verbs inline query remote node out of range"); | |
| 631 | } | ||
| 632 | ✗ | return impl_->max_inline_data_per_node[static_cast<std::size_t>(node_id)]; | |
| 633 | } | ||
| 634 | |||
| 635 | } // namespace petps | ||
| 636 |