ps/rdma/raw_verbs_transport.h
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include <cstddef> | ||
| 4 | #include <cstdint> | ||
| 5 | #include <memory> | ||
| 6 | #include <stdexcept> | ||
| 7 | #include <string> | ||
| 8 | #include <vector> | ||
| 9 | |||
| 10 | #include <infiniband/verbs.h> | ||
| 11 | |||
| 12 | #include "base/array.h" | ||
| 13 | #include "ps/rdma/global_address.h" | ||
| 14 | |||
| 15 | namespace petps { | ||
| 16 | |||
// Maps a NUMA node id onto a device index in [0, device_count).
//
// Returns 0 when there are no devices or when numa_id is zero or negative;
// otherwise returns numa_id clamped to the last valid index.
inline int SelectRawVerbsDeviceIndex(int numa_id, int device_count) {
  const bool use_first_device = device_count <= 0 || numa_id <= 0;
  if (use_first_device) {
    return 0;
  }
  const int last_index = device_count - 1;
  return numa_id < device_count ? numa_id : last_index;
}
| 26 | |||
// Plain configuration aggregate passed to RawVerbsTransport's constructor.
// Member order is part of the interface (aggregate initialization).
struct RawVerbsConfig {
  // Id of this node; ShouldRawVerbsConnectToNode never connects a node to
  // itself.
  int global_id{0};
  int local_lane{0};
  int remote_lane{0};
  // Optional single peer node filter; negative disables the filter.
  int only_node_id{-1};
  // Node ids below num_servers are treated as servers; valid ids are
  // [0, num_servers + num_clients).
  int num_servers{1};
  int num_clients{1};
  // NOTE(review): presumably fed to SelectRawVerbsDeviceIndex - confirm.
  int numa_id{0};
  std::uint32_t max_inline_data{0};
  // Whether peers classified as servers / clients should be connected.
  bool connect_to_servers{true};
  bool connect_to_clients{true};
  // Defaults to 128 MiB.
  std::size_t local_region_bytes{128 * 1024 * 1024};
  std::uint64_t local_base_addr{0};
  std::string control_plane_host{"127.0.0.1"};
  int control_plane_port{25100};
  int control_plane_timeout_ms{30000};
  std::uint64_t allocation_start_offset{0};
  std::uint64_t reserved_region_offset{0};
  std::uint64_t reserved_region_bytes{0};
};
| 47 | |||
// Builds the key string identifying one publisher/receiver endpoint pair:
//   raw-verbs-meta-<publisher node>-lane-<publisher lane>-to-
//   <receiver node>-lane-<receiver lane>
inline std::string RawVerbsMetaKey(int publisher_node_id,
                                   int publisher_lane,
                                   int receiver_node_id,
                                   int receiver_lane) {
  std::string key = "raw-verbs-meta-";
  key += std::to_string(publisher_node_id);
  key += "-lane-";
  key += std::to_string(publisher_lane);
  key += "-to-";
  key += std::to_string(receiver_node_id);
  key += "-lane-";
  key += std::to_string(receiver_lane);
  return key;
}
| 58 | |||
| 59 | inline bool | ||
| 60 | 16 | ShouldRawVerbsConnectToNode(const RawVerbsConfig& config, int node_id) { | |
| 61 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 12 times.
|
16 | if (node_id == config.global_id) { |
| 62 | 4 | return false; | |
| 63 | } | ||
| 64 |
2/4✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
|
12 | if (node_id < 0 || node_id >= config.num_servers + config.num_clients) { |
| 65 | ✗ | return false; | |
| 66 | } | ||
| 67 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
12 | if (config.only_node_id >= 0 && node_id != config.only_node_id) { |
| 68 | ✗ | return false; | |
| 69 | } | ||
| 70 | 12 | const bool is_server = node_id < config.num_servers; | |
| 71 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
|
12 | return is_server ? config.connect_to_servers : config.connect_to_clients; |
| 72 | } | ||
| 73 | |||
// Byte range [offset, offset + bytes) that RawVerbsRegionAllocator must never
// hand out. A region with bytes == 0 means "nothing reserved".
struct RawVerbsReservedRegion {
  std::uint64_t offset = 0;
  std::uint64_t bytes = 0;
};

// Bump allocator over the offset range [0, limit_bytes).
//
// Allocation sizes are rounded up to a multiple of 64 bytes. Only the size is
// rounded; the returned offset is not realigned. One reserved region may be
// installed and is skipped by allocations. All bounds checks are written with
// subtraction so that no intermediate sum can wrap around.
class RawVerbsRegionAllocator {
 public:
  // Throws std::runtime_error if allocation_start_offset > limit_bytes.
  explicit RawVerbsRegionAllocator(std::uint64_t limit_bytes,
                                   std::uint64_t allocation_start_offset = 0)
      : limit_bytes_(limit_bytes), allocation_offset_(allocation_start_offset) {
    if (allocation_start_offset > limit_bytes) {
      throw std::runtime_error(
          "raw verbs allocation start outside local memory");
    }
  }

  // Installs (or, with bytes == 0, clears) the reserved region.
  // Throws std::runtime_error if a non-empty region does not fit inside
  // [0, limit_bytes]; the previous region is kept in that case.
  // If the current allocation offset lies inside the new region it is moved
  // to the region's end.
  void SetReservedRegion(RawVerbsReservedRegion reserved) {
    const bool has_region = reserved.bytes != 0;
    if (has_region && (reserved.offset > limit_bytes_ ||
                       reserved.bytes > limit_bytes_ - reserved.offset)) {
      throw std::runtime_error(
          "raw verbs reserved region outside local memory");
    }
    reserved_ = reserved;
    if (has_region && allocation_offset_ >= reserved.offset &&
        allocation_offset_ - reserved.offset < reserved.bytes) {
      // Validated above, so this sum cannot overflow.
      allocation_offset_ = reserved.offset + reserved.bytes;
    }
  }

  // Reserves `bytes` (rounded up to a multiple of 64) and returns the offset
  // of the reservation. Throws std::runtime_error, leaving the allocator
  // unchanged, when the request does not fit, including requests so large
  // that rounding them up would overflow 64 bits.
  std::uint64_t Allocate(std::size_t bytes) {
    const std::uint64_t requested = static_cast<std::uint64_t>(bytes);
    if (requested > ~std::uint64_t{0} - kAlignMask) {
      // Rounding up would wrap to a tiny size and silently "succeed".
      throw std::runtime_error("raw verbs registered region exhausted");
    }
    const std::uint64_t aligned = Align(requested);

    std::uint64_t offset = allocation_offset_;
    if (reserved_.bytes != 0) {
      const std::uint64_t reserved_begin = reserved_.offset;
      const std::uint64_t reserved_end = reserved_.offset + reserved_.bytes;
      const bool starts_inside =
          offset >= reserved_begin && offset < reserved_end;
      const bool runs_into =
          offset < reserved_begin && aligned > reserved_begin - offset;
      if (starts_inside || runs_into) {
        // Skip the reserved bytes; any gap before them is left unused.
        offset = reserved_end;
      }
    }

    if (offset > limit_bytes_ || aligned > limit_bytes_ - offset) {
      throw std::runtime_error("raw verbs registered region exhausted");
    }
    allocation_offset_ = offset + aligned;
    return offset;
  }

  // Returns a value that can later be passed to Restore().
  std::uint64_t Checkpoint() const { return allocation_offset_; }

  // Rewinds (or advances) the allocation offset to `checkpoint`.
  // Throws std::runtime_error if checkpoint > limit_bytes.
  void Restore(std::uint64_t checkpoint) {
    if (checkpoint > limit_bytes_) {
      throw std::runtime_error(
          "raw verbs allocation checkpoint outside local memory");
    }
    allocation_offset_ = checkpoint;
  }

  std::uint64_t current_offset() const { return allocation_offset_; }

 private:
  static constexpr std::uint64_t kAlignMask = 63;

  // Rounds up to a multiple of 64. Caller guarantees bytes + 63 cannot wrap.
  static std::uint64_t Align(std::uint64_t bytes) {
    return (bytes + kAlignMask) & ~kAlignMask;
  }

  std::uint64_t limit_bytes_ = 0;
  std::uint64_t allocation_offset_ = 0;
  RawVerbsReservedRegion reserved_{};
};
| 146 | |||
| 147 | class RawVerbsRegionAllocatorScope { | ||
| 148 | public: | ||
| 149 | 2 | explicit RawVerbsRegionAllocatorScope(RawVerbsRegionAllocator* allocator) | |
| 150 | 2 | : allocator_(allocator), | |
| 151 |
1/2✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
|
2 | checkpoint_(allocator != nullptr ? allocator->Checkpoint() : 0) {} |
| 152 | |||
| 153 | 2 | ~RawVerbsRegionAllocatorScope() { | |
| 154 |
1/2✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
|
2 | if (allocator_ != nullptr) { |
| 155 | 2 | allocator_->Restore(checkpoint_); | |
| 156 | } | ||
| 157 | 2 | } | |
| 158 | |||
| 159 | RawVerbsRegionAllocatorScope(const RawVerbsRegionAllocatorScope&) = delete; | ||
| 160 | RawVerbsRegionAllocatorScope& | ||
| 161 | operator=(const RawVerbsRegionAllocatorScope&) = delete; | ||
| 162 | |||
| 163 | private: | ||
| 164 | RawVerbsRegionAllocator* allocator_ = nullptr; | ||
| 165 | std::uint64_t checkpoint_ = 0; | ||
| 166 | }; | ||
| 167 | |||
// Value triple naming a remote memory region: node id, base address and rkey.
// All fields default to zero.
struct RawVerbsRemoteMemory {
  std::uint16_t node_id{0};
  std::uint64_t base_addr{0};
  std::uint32_t rkey{0};
};
| 173 | |||
| 174 | struct RawVerbsCompletion { | ||
| 175 | std::uint64_t wr_id = 0; | ||
| 176 | std::uint32_t imm_data = 0; | ||
| 177 | bool has_imm = false; | ||
| 178 | ibv_wc_opcode opcode = IBV_WC_SEND; | ||
| 179 | }; | ||
| 180 | |||
// One scatter/gather element: a read-only pointer plus a byte count.
// Defaults to an empty element (null pointer, zero bytes).
struct RawVerbsSge {
  const void* data{nullptr};
  std::size_t bytes{0};
};
| 185 | |||
// Batch size constant for completion polling.
// NOTE(review): presumably the entry count passed to ibv_poll_cq - confirm.
inline constexpr int kRawVerbsPollBatchSize{16};
| 187 | |||
| 188 | class RawVerbsCompletionBatchCursor { | ||
| 189 | public: | ||
| 190 | ✗ | bool HasCachedCompletion() const { return current_ < size_; } | |
| 191 | |||
| 192 | ✗ | void Reset(ibv_wc* entries, int size) { | |
| 193 | ✗ | entries_ = entries; | |
| 194 | ✗ | size_ = size; | |
| 195 | ✗ | current_ = 0; | |
| 196 | ✗ | } | |
| 197 | |||
| 198 | ✗ | ibv_wc* TakeCachedCompletion() { | |
| 199 | ✗ | if (!HasCachedCompletion()) { | |
| 200 | ✗ | return nullptr; | |
| 201 | } | ||
| 202 | ✗ | return entries_ + current_++; | |
| 203 | } | ||
| 204 | |||
| 205 | private: | ||
| 206 | ibv_wc* entries_ = nullptr; | ||
| 207 | int size_ = 0; | ||
| 208 | int current_ = 0; | ||
| 209 | }; | ||
| 210 | |||
// Per-node connection metadata record (returned by
// RawVerbsTransport::LocalMeta). Every field defaults to zero except `psn`,
// which defaults to 3185.
struct RawVerbsNodeMeta {
  std::uint16_t node_id{0};
  std::uint16_t lid{0};
  std::uint32_t qpn{0};
  std::uint32_t psn{3185};
  std::uint32_t rkey{0};
  std::uint64_t base_addr{0};
  std::uint8_t gid[16]{};
};
| 220 | |||
| 221 | class RawVerbsTransport { | ||
| 222 | public: | ||
| 223 | explicit RawVerbsTransport(const RawVerbsConfig& config); | ||
| 224 | ~RawVerbsTransport(); | ||
| 225 | |||
| 226 | RawVerbsTransport(const RawVerbsTransport&) = delete; | ||
| 227 | RawVerbsTransport& operator=(const RawVerbsTransport&) = delete; | ||
| 228 | |||
| 229 | void RegisterThread(); | ||
| 230 | void RegisterMemoryRegion(void* base, std::size_t bytes); | ||
| 231 | void* AllocateRegistered(std::size_t bytes); | ||
| 232 | std::uint64_t SaveAllocationState() const; | ||
| 233 | void RestoreAllocationState(std::uint64_t checkpoint); | ||
| 234 | GlobalAddress LocalAddress(void* ptr) const; | ||
| 235 | void* LocalPointer(GlobalAddress address) const; | ||
| 236 | |||
| 237 | void PublishAndConnect(); | ||
| 238 | RawVerbsNodeMeta LocalMeta() const; | ||
| 239 | |||
| 240 | void Write(const void* local, | ||
| 241 | GlobalAddress remote, | ||
| 242 | std::size_t bytes, | ||
| 243 | std::uint64_t wr_id, | ||
| 244 | bool signaled); | ||
| 245 | void WriteSg(base::ConstArray<RawVerbsSge> sges, | ||
| 246 | GlobalAddress remote, | ||
| 247 | std::uint64_t wr_id, | ||
| 248 | bool signaled); | ||
| 249 | void WriteWithImm(const void* local, | ||
| 250 | GlobalAddress remote, | ||
| 251 | std::size_t bytes, | ||
| 252 | std::uint32_t imm_data, | ||
| 253 | std::uint64_t wr_id, | ||
| 254 | bool signaled); | ||
| 255 | void Read(void* local, | ||
| 256 | GlobalAddress remote, | ||
| 257 | std::size_t bytes, | ||
| 258 | std::uint64_t wr_id, | ||
| 259 | bool signaled); | ||
| 260 | void SendDoorbell( | ||
| 261 | std::uint16_t node_id, std::uint32_t imm_data, std::uint64_t wr_id); | ||
| 262 | bool Poll(RawVerbsCompletion* completion, int timeout_ms); | ||
| 263 | std::uint32_t max_inline_data(std::uint16_t node_id) const; | ||
| 264 | |||
| 265 | private: | ||
| 266 | struct Impl; | ||
| 267 | ibv_mr* FindLocalMr(const void* ptr, std::size_t bytes) const; | ||
| 268 | std::unique_ptr<Impl> impl_; | ||
| 269 | }; | ||
| 270 | |||
| 271 | class RawVerbsTransportAllocationScope { | ||
| 272 | public: | ||
| 273 | explicit RawVerbsTransportAllocationScope(RawVerbsTransport* transport) | ||
| 274 | : transport_(transport), | ||
| 275 | checkpoint_( | ||
| 276 | transport != nullptr ? transport->SaveAllocationState() : 0) {} | ||
| 277 | |||
| 278 | ~RawVerbsTransportAllocationScope() { | ||
| 279 | if (transport_ != nullptr) { | ||
| 280 | transport_->RestoreAllocationState(checkpoint_); | ||
| 281 | } | ||
| 282 | } | ||
| 283 | |||
| 284 | RawVerbsTransportAllocationScope(const RawVerbsTransportAllocationScope&) = | ||
| 285 | delete; | ||
| 286 | RawVerbsTransportAllocationScope& | ||
| 287 | operator=(const RawVerbsTransportAllocationScope&) = delete; | ||
| 288 | |||
| 289 | private: | ||
| 290 | RawVerbsTransport* transport_ = nullptr; | ||
| 291 | std::uint64_t checkpoint_ = 0; | ||
| 292 | }; | ||
| 293 | |||
| 294 | } // namespace petps | ||
| 295 |