GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 32.1% 26 / 0 / 81
Functions: 28.6% 4 / 0 / 14
Branches: 23.3% 14 / 0 / 60

ps/rdma/rdma_protocol.h
Line Branch Exec Source
1 #pragma once
2
#include <algorithm>
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <string_view>
#include <vector>

#include "base/flatc.h"
#include "base/log.h"
#include "ps/base/parameters.h"
#include "ps/rdma/rdma_status.h"
16
17 namespace petps {
18
// Identity of the RC wire protocol. The magic bytes read "RCPS" in ASCII.
inline constexpr std::uint32_t kRcProtocolMagic{0x52435053U};
inline constexpr std::uint16_t kRcProtocolVersion{1};

// Capacity in bytes of RequestDescriptor::table_name.
inline constexpr std::size_t kTableNameBytes{64};

// Slot states; StatusWordDone() compares StatusWord::state to kRcSlotDone.
inline constexpr std::uint32_t kRcSlotReady{1};
inline constexpr std::uint32_t kRcSlotDone{2};

// GET flag bits. NOTE(review): presumably carried in RequestDescriptor::flags
// -- confirm against the sender/receiver code.
inline constexpr std::uint32_t kRcFlagGetDirectSg{1U << 0};
inline constexpr std::uint32_t kRcFlagGetAllowFallbackCopy{1U << 1};
26
// Operation codes of the RC protocol. RequestDescriptor::op holds one of
// these as a raw 16-bit value.
enum class RcOp : std::uint16_t {
  kGet = 1,
  kPut = 2,
  kUpdate = 3,
  kInitTable = 4,
};
33
// Hash-method identifiers, stored in a single byte.
enum class RcHashMethod : std::uint8_t {
  kCityHash = 1,
  kSimpleMod = 2,
};
38
39 struct alignas(64) RequestDescriptor {
40 std::uint32_t magic = kRcProtocolMagic;
41 std::uint16_t version = kRcProtocolVersion;
42 std::uint16_t op = static_cast<std::uint16_t>(RcOp::kGet);
43 std::uint64_t seq = 0; // Monotonic lane-local request sequence.
44 std::uint32_t shard_id = 0; // Logical shard targeted by this RPC.
45 std::uint32_t client_id = 0; // Logical client owner of this lane.
46 std::uint32_t qp_index = 0; // Lane index within the client.
47 std::uint32_t key_count = 0; // Number of keys in the payload.
48 std::uint32_t value_size = 0; // Row size in bytes for GET responses.
49 std::uint32_t embedding_dim = 0; // Row size expressed as float count.
50 std::uint32_t payload_offset = 0; // Offset from slot base to payload.
51 std::uint32_t payload_bytes = 0; // Bytes occupied by the payload.
52 std::uint32_t response_bytes = 0; // Bytes expected in the response payload.
53 std::uint32_t reserved0 = 0;
54 std::uint64_t client_response_addr =
55 0; // Optional client response address for verbs RC.
56 std::uint32_t client_response_rkey =
57 0; // Optional client response remote key.
58 std::uint32_t client_status_rkey = 0; // Optional client status remote key.
59 std::uint64_t client_status_addr =
60 0; // Optional client status address for verbs RC.
61 std::uint32_t flags = 0; // Op-specific protocol flags.
62 std::uint32_t reserved1 = 0;
63 std::array<char, kTableNameBytes>
64 table_name{}; // Optional logical table name.
65 };
66
// 64-byte aligned commit marker for a request slot.
struct alignas(64) CommitWord {
  // Mirrors RequestDescriptor::seq.
  std::atomic<std::uint64_t> seq{0};
  // READY/DONE state published by client/server.
  std::atomic<std::uint32_t> state{0};
  // Reserved for future integrity checks.
  std::uint32_t checksum_or_reserved{0};
};
74
// 64-byte aligned completion record for one request.
struct alignas(64) StatusWord {
  // Mirrors the request seq completed here.
  std::atomic<std::uint64_t> seq{0};
  // DONE when the server has finished writing.
  std::atomic<std::uint32_t> state{0};
  // RpcStatus value returned by the server.
  std::int32_t status{0};
  // Payload bytes valid in the response slot.
  std::uint32_t response_bytes{0};
  std::uint32_t reserved{0};
};
83
84 static_assert(sizeof(RequestDescriptor) == 192, "RequestDescriptor size");
85 static_assert(alignof(RequestDescriptor) == 64, "RequestDescriptor align");
86 static_assert(alignof(CommitWord) == 64, "CommitWord align");
87 static_assert(alignof(StatusWord) == 64, "StatusWord align");
88
/// Rounds `value` up to the nearest multiple of 64.
///
/// Usable in constant expressions, so calls with constant arguments can be
/// evaluated at compile time.
///
/// @param value Byte count to round up. Must be <= SIZE_MAX - 63; larger
///              inputs wrap around and yield 0.
/// @return The smallest multiple of 64 that is >= `value`.
inline constexpr std::size_t Align64(std::size_t value) noexcept {
  constexpr std::size_t kLowBits = 63U;
  return (value + kLowBits) & ~kLowBits;
}
92
/// Number of rows of `value_size` bytes that fit in a response budget of
/// `mtu_bytes * response_mtu` bytes.
///
/// @param value_size   Bytes per row; 0 yields 0.
/// @param mtu_bytes    Bytes per MTU unit; 0 yields 0.
/// @param response_mtu Number of MTU units in the budget; 0 yields 0.
/// @return floor(budget / value_size). If the budget product does not fit in
///         std::size_t it saturates at SIZE_MAX instead of wrapping, so the
///         result stays a lower bound on the true quotient and is never a
///         wrapped-around garbage value.
inline std::size_t GetKeysPerRpcByResponseBudget(
    std::size_t value_size, std::size_t mtu_bytes, std::size_t response_mtu) {
  if (value_size == 0 || mtu_bytes == 0 || response_mtu == 0) {
    return 0;
  }
  // response_mtu is non-zero here, so the division is safe.
  const bool product_overflows = mtu_bytes > SIZE_MAX / response_mtu;
  const std::size_t budget_bytes =
      product_overflows ? SIZE_MAX : mtu_bytes * response_mtu;
  return budget_bytes / value_size;
}
100
// Bytes needed to carry `key_count` keys, each a 64-bit unsigned integer.
inline std::size_t GetRequestBytes(std::size_t key_count) {
  constexpr std::size_t kBytesPerKey = sizeof(std::uint64_t);
  return kBytesPerKey * key_count;
}
104
// Bytes occupied by `key_count` rows of `value_size` bytes each.
inline std::size_t GetResponseBytes(std::size_t key_count,
                                    std::size_t value_size) {
  const std::size_t total_bytes = value_size * key_count;
  return total_bytes;
}
109
110 inline std::size_t
111 FixedSlotResponseBytes(std::size_t key_count, std::size_t value_size) {
112 return GetResponseBytes(key_count, value_size) + sizeof(std::int32_t);
113 }
114
// Size of the init-table payload: two 64-bit unsigned integers.
inline std::size_t InitTablePayloadBytes() {
  constexpr std::size_t kFieldCount = 2;
  return kFieldCount * sizeof(std::uint64_t);
}
116
117 inline std::size_t PutPayloadBudget(std::size_t request_slot_bytes) {
118 if (request_slot_bytes <=
119 Align64(sizeof(RequestDescriptor)) + Align64(sizeof(CommitWord))) {
120 return 0;
121 }
122 return request_slot_bytes - Align64(sizeof(RequestDescriptor)) -
123 Align64(sizeof(CommitWord));
124 }
125
126 inline std::size_t ParameterReaderBytes(const ParameterCompressReader& reader) {
127 return static_cast<std::size_t>(reader.byte_size());
128 }
129
130 2 inline std::size_t PutPayloadBytes(
131 const std::vector<std::uint64_t>& keys,
132 const std::vector<std::vector<float>>& values,
133 std::string* payload,
134 std::string* error = nullptr) {
135
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
2 if (payload == nullptr) {
136 if (error != nullptr) {
137 *error = "payload buffer is null";
138 }
139 return 0;
140 }
141
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
2 if (keys.size() != values.size()) {
142 if (error != nullptr) {
143 *error = "keys and values size mismatch";
144 }
145 return 0;
146 }
147
148
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 ParameterCompressor compressor;
149
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 2 times.
6 for (std::size_t i = 0; i < keys.size(); ++i) {
150 4 ParameterPack pack;
151 4 pack.key = keys[i];
152 4 pack.dim = static_cast<int>(values[i].size());
153 4 pack.emb_data = values[i].data();
154
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 compressor.AddItem(pack, nullptr);
155 }
156
157 2 payload->clear();
158
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 compressor.ToBlock(payload);
159 2 return payload->size();
160 2 }
161
162 inline std::size_t UpdatePayloadBytes(
163 const std::vector<std::uint64_t>& keys,
164 const std::vector<std::vector<float>>& values,
165 std::string* payload,
166 std::string* error = nullptr) {
167 return PutPayloadBytes(keys, values, payload, error);
168 }
169
170 inline bool CopyTableName(std::string_view table_name,
171 std::array<char, kTableNameBytes>* storage) {
172 if (storage == nullptr || table_name.size() >= kTableNameBytes) {
173 return false;
174 }
175 storage->fill('\0');
176 std::memcpy(storage->data(), table_name.data(), table_name.size());
177 return true;
178 }
179
180 inline std::string_view
181 DescriptorTableName(const RequestDescriptor& descriptor) {
182 return std::string_view(
183 descriptor.table_name.data(),
184 std::find(
185 descriptor.table_name.begin(), descriptor.table_name.end(), '\0') -
186 descriptor.table_name.begin());
187 }
188
189 inline bool ValidateRequestDescriptor(
190 const RequestDescriptor& descriptor,
191 std::size_t request_slot_bytes,
192 std::size_t response_slot_bytes,
193 std::string* error = nullptr) {
194 if (descriptor.magic != kRcProtocolMagic) {
195 if (error != nullptr) {
196 *error = "bad request magic";
197 }
198 return false;
199 }
200 if (descriptor.version != kRcProtocolVersion) {
201 if (error != nullptr) {
202 *error = "bad request version";
203 }
204 return false;
205 }
206 if (descriptor.payload_offset < sizeof(RequestDescriptor) ||
207 static_cast<std::size_t>(descriptor.payload_offset) +
208 descriptor.payload_bytes >
209 request_slot_bytes) {
210 if (error != nullptr) {
211 *error = "request payload exceeds slot capacity";
212 }
213 return false;
214 }
215 if (descriptor.response_bytes > response_slot_bytes) {
216 if (error != nullptr) {
217 *error = "response exceeds slot capacity";
218 }
219 return false;
220 }
221 return true;
222 }
223
224 2 inline void ResetStatusWord(StatusWord* status, std::uint64_t seq) {
225 2 status->status = static_cast<std::int32_t>(RpcStatus::kPending);
226 2 status->response_bytes = 0;
227 2 status->seq.store(seq, std::memory_order_release);
228 2 status->state.store(0, std::memory_order_release);
229 2 }
230
231 6 inline bool StatusWordDone(const StatusWord& status, std::uint64_t seq) {
232
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 2 times.
16 return status.state.load(std::memory_order_acquire) == kRcSlotDone &&
233
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
14 status.seq.load(std::memory_order_acquire) == seq;
234 }
235
236 } // namespace petps
237