GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 67.5% 222 / 0 / 329
Functions: 93.5% 29 / 0 / 31
Branches: 34.8% 272 / 0 / 781

framework/op.cc
Line Branch Exec Source
1 #include "framework/op.h"
2 #include "framework/common/hierkv_local_runtime.h"
3 #include "framework/common/local_shm_op_component.h"
4 #include "framework/common/op_runtime_support.h"
5 #include "framework/common/ps_client_config_adapter.h"
6 #include "ps/client_factory.h"
7 #include "ps/brpc/dist_brpc_ps_client.h"
8 #include "ps/grpc/dist_grpc_ps_client.h"
9 #include "ps/rdma/rdma_ps_client_adapter.h"
10 #include "base/factory.h"
11 #include <algorithm>
12 #include <cctype>
13 #include <cstring>
14 #include <immintrin.h>
15 #include <iostream>
16 #include <stdexcept>
17 #include <vector>
18 #include <unordered_map>
19 #include <filesystem>
20 #include <mutex>
21 #include <memory>
22 #include <numeric>
23 #include <thread>
24 #include <cstdlib>
25 #include <emmintrin.h>
26 #include <string>
27 #include <fstream>
28 #include "base/tensor.h"
29 #include <glog/logging.h>
30 #ifdef ENABLE_PERF_REPORT
31 # include "base/report/report_client.h"
32 #endif
33
34 namespace recstore {
35
36 namespace {
37 24 std::string NormalizeBackendName(std::string backend_name) {
38 24 std::transform(
39 backend_name.begin(),
40 backend_name.end(),
41 backend_name.begin(),
42 136 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
43 24 return backend_name;
44 }
45
46 58 bool IsReadWriteSuccess(BasePSClient* client, int ret) {
47
3/4
✓ Branch 0 taken 58 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 50 times.
116 if (dynamic_cast<RDMAPSClientAdapter*>(client) != nullptr ||
48
2/4
✓ Branch 0 taken 58 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
66 dynamic_cast<DistributedGRPCParameterClient*>(client) != nullptr ||
49
5/8
✓ Branch 0 taken 58 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
✓ Branch 6 taken 50 times.
✓ Branch 7 taken 8 times.
124 dynamic_cast<DistributedBRPCParameterClient*>(client) != nullptr ||
50
1/2
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
8 dynamic_cast<LocalShmPSClient*>(client) != nullptr) {
51 50 return ret == 0;
52 }
53 // Legacy GRPC/BRPC read/write methods return bool-like int values.
54 8 return ret != 0;
55 }
56
57 24 std::string ResolveBackendNameWithHierKV(const json& config) {
58
3/6
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 5 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
24 if (config.contains("cache_ps") && config["cache_ps"].contains("ps_type")) {
59 const std::string ps_type =
60
4/8
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
24 NormalizeBackendName(config["cache_ps"]["ps_type"].get<std::string>());
61
3/4
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 20 times.
✓ Branch 4 taken 4 times.
24 if (IsHierKVBackendName(ps_type)) {
62
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 return ps_type;
63 }
64
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 20 times.
24 }
65
1/5
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
4 switch (ResolveFrameworkPSClientType(config)) {
66 4 case PSClientType::kGrpc:
67
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 return HasFrameworkDistributedClientConfig(config)
68 ? "distributed_grpc"
69
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
8 : "grpc";
70 case PSClientType::kBrpc:
71 return HasFrameworkDistributedClientConfig(config)
72 ? "distributed_brpc"
73 : "brpc";
74 case PSClientType::kRdma:
75 return "rdma";
76 case PSClientType::kLocalShm:
77 return "local_shm";
78 }
79
80 return "unknown";
81 }
82 } // namespace
83
84 70 void validate_keys(const base::RecTensor& keys) {
85
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 68 times.
70 if (keys.dtype() != base::DataType::UINT64) {
86
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
2 throw std::invalid_argument("Keys tensor must have dtype UINT64, but got " +
87
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
4 base::DataTypeToString(keys.dtype()));
88 }
89
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
68 if (keys.dim() != 1) {
90 throw std::invalid_argument("Keys tensor must be 1-dimensional, but has " +
91 std::to_string(keys.dim()) + " dimensions.");
92 }
93 68 }
94
95 66 void validate_embeddings(const base::RecTensor& embeddings,
96 const std::string& name) {
97
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
66 if (embeddings.dtype() != base::DataType::FLOAT32) {
98 throw std::invalid_argument(
99 name + " tensor must have dtype FLOAT32, but got " +
100 base::DataTypeToString(embeddings.dtype()));
101 }
102
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
66 if (embeddings.dim() != 2) {
103 throw std::invalid_argument(
104 name + " tensor must be 2-dimensional, but has " +
105 std::to_string(embeddings.dim()) + " dimensions.");
106 }
107 // No fixed embedding dimension check for mock.
108 66 }
109
110 2 void KVClientOp::EmbInit(const base::RecTensor& keys,
111 const base::RecTensor& init_values) {
112 2 EmbWrite(keys, init_values);
113 2 }
114
115 2 void KVClientOp::EmbDelete(const base::RecTensor& keys) {
116
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Not impl");
117 }
118 2 bool KVClientOp::EmbExists(const base::RecTensor& keys) {
119
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Not impl");
120 }
121
122 2 void KVClientOp::WaitForWrite(uint64_t write_id) {
123
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Not impl");
124 }
125 2 void KVClientOp::SaveToFile(const std::string& path) {
126
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Not impl");
127 }
128 2 void KVClientOp::LoadFromFile(const std::string& path) {
129
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Not impl");
130 }
131
132 2 uint64_t KVClientOp::EmbWriteAsync(const base::RecTensor& keys,
133 const base::RecTensor& values) {
134
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Not impl");
135 }
136
137 70 std::shared_ptr<CommonOp> GetKVClientOp() {
138
3/4
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 64 times.
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
70 static std::shared_ptr<CommonOp> instance;
139 static std::once_flag once_flag;
140
1/2
✓ Branch 1 taken 70 times.
✗ Branch 2 not taken.
70 std::call_once(once_flag, []() {
141 6 instance = std::make_shared<KVClientOp>();
142 6 });
143 70 return instance;
144 }
145
146 } // namespace recstore
147
148 #ifndef USE_FAKE_KVCLIENT
149
150 namespace recstore {
151
152
1/2
✓ Branch 3 taken 34 times.
✗ Branch 4 not taken.
34 KVClientOp::KVClientOp() {
153
2/2
✓ Branch 0 taken 24 times.
✓ Branch 1 taken 10 times.
34 if (!ps_client_) {
154 try {
155
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 json config = GetGlobalConfig();
156
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 ps_backend_name_ = ResolveBackendNameWithHierKV(config);
157
3/4
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 20 times.
✓ Branch 4 taken 4 times.
24 if (IsHierKVBackendName(ps_backend_name_)) {
158
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 ConfigureLogging();
159
3/6
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
20 LOG(INFO) << "Initialized local HierKV backend in KVClientOp.";
160 20 return;
161 }
162 4 bool use_rdma = false;
163 try {
164
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 use_rdma = ResolveFrameworkPSClientType(config) == PSClientType::kRdma;
165 } catch (...) {
166 use_rdma = false;
167 }
168 std::cerr << "[RDMA-DBG] KVClientOp ctor use_rdma="
169
4/8
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✓ Branch 6 taken 4 times.
✗ Branch 7 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 10 not taken.
4 << (use_rdma ? "true" : "false") << std::endl;
170
171
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
4 if (use_rdma) {
172 std::cerr << "[RDMA-DBG] InitializeRdmaProcessRuntime before "
173 "ConfigureLogging"
174 << std::endl;
175 InitializeRdmaProcessRuntime();
176 ConfigureLogging(false);
177 } else {
178
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 ConfigureLogging();
179 }
180
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 ps_client_holder_ = create_ps_client_from_config(config);
181 4 ps_client_ = ps_client_holder_.get();
182
183
3/6
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
4 LOG(INFO) << "PS client initialized successfully.";
184
2/4
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 20 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
24 } catch (const std::exception& e) {
185 LOG(ERROR) << "Failed to initialize PS client: " << std::string(e.what());
186 throw;
187 }
188 }
189 }
190
191 BasePSClient* KVClientOp::ps_client_ = nullptr;
192 std::unique_ptr<BasePSClient> KVClientOp::ps_client_holder_;
193
194 void KVClientOp::SetPSConfig(const std::string& host, int port) {
195 if (IsHierKVBackendName(ps_backend_name_)) {
196 LOG(INFO) << "HierKV backend ignores set_ps_config host=" << host
197 << " port=" << port;
198 return;
199 }
200 ps_client_holder_.reset();
201 ps_client_ = nullptr;
202
203 json file_config = GetGlobalConfig();
204 int final_port = port;
205 if (final_port <= 0) {
206 if (file_config.contains("client") &&
207 file_config["client"].contains("port")) {
208 final_port = file_config["client"]["port"].get<int>();
209 } else if (file_config.contains("cache_ps") &&
210 file_config["cache_ps"].contains("servers") &&
211 file_config["cache_ps"]["servers"].is_array() &&
212 !file_config["cache_ps"]["servers"].empty()) {
213 final_port = file_config["cache_ps"]["servers"][0]["port"].get<int>();
214 } else {
215 final_port = 15000;
216 }
217 }
218
219 std::string final_host = host;
220 if (final_host.empty()) {
221 final_host = "127.0.0.1";
222 }
223
224 json config = file_config;
225 config.erase("distributed_client");
226 if (!config.contains("client")) {
227 config["client"] = json::object();
228 }
229 config["client"]["host"] = final_host;
230 config["client"]["port"] = final_port;
231 config["client"]["shard"] = 0;
232
233 ps_client_holder_ = create_ps_client_from_config(config);
234 ps_client_ = ps_client_holder_.get();
235 ps_backend_name_ = ResolveBackendNameWithHierKV(config);
236 LOG(INFO) << "Re-initialized PS client with host=" << final_host
237 << " port=" << final_port;
238 }
239
240 void KVClientOp::SetPSBackend(const std::string& backend) {
241 if (backend.empty()) {
242 throw std::invalid_argument("backend must be non-empty");
243 }
244
245 const std::string normalized_backend = NormalizeBackendName(backend);
246 if (IsHierKVBackendName(normalized_backend)) {
247 ps_client_holder_.reset();
248 ps_client_ = nullptr;
249 ps_backend_name_ = normalized_backend;
250 LOG(INFO) << "Switched KVClientOp backend to local HierKV runtime.";
251 return;
252 }
253
254 json config = GetGlobalConfig();
255 if (!config.contains("cache_ps")) {
256 config["cache_ps"] = json::object();
257 }
258 config["cache_ps"]["ps_type"] = NormalizePSType(backend);
259
260 ps_client_holder_.reset();
261 ps_client_ = nullptr;
262 ps_client_holder_ = create_ps_client_from_config(config);
263 ps_client_ = ps_client_holder_.get();
264 ps_backend_name_ = ResolveBackendNameWithHierKV(config);
265 LOG(INFO) << "Re-initialized PS client with backend=" << ps_backend_name_;
266 }
267
268 2 std::string KVClientOp::CurrentPSBackend() const { return ps_backend_name_; }
269
270 48 void KVClientOp::EmbRead(const RecTensor& keys, RecTensor& values) {
271
3/4
✓ Branch 1 taken 48 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✓ Branch 4 taken 38 times.
48 if (IsHierKVBackendName(ps_backend_name_)) {
272
3/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✓ Branch 5 taken 4 times.
10 GetHierKVLocalRuntime().Read(keys, values);
273 6 return;
274 }
275
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 38 times.
38 if (ps_client_ == nullptr) {
276 throw std::runtime_error("PS client is not initialized. Please call "
277 "KVClientOp::SetPSClient() first.");
278 }
279
280 # ifdef ENABLE_PERF_REPORT
281 auto start_time = std::chrono::high_resolution_clock::now();
282 double start_us =
283 std::chrono::duration_cast<std::chrono::microseconds>(
284 start_time.time_since_epoch())
285 .count();
286 std::string report_id =
287 "op::EmbRead|" + std::to_string(static_cast<uint64_t>(start_us));
288 std::string unique_id =
289 "embread_debug|" + std::to_string(static_cast<uint64_t>(start_us));
290 # endif
291
292
6/12
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 38 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 38 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 38 times.
✗ Branch 14 not taken.
✓ Branch 16 taken 38 times.
✗ Branch 17 not taken.
76 LOG(INFO) << "EmbRead: keys.shape=" << keys.shape(0) << ", values.shape=["
293
6/12
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 38 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 38 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 38 times.
✗ Branch 14 not taken.
✓ Branch 16 taken 38 times.
✗ Branch 17 not taken.
38 << values.shape(0) << ", " << values.shape(1) << "]";
294
5/10
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 38 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 38 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 38 times.
✗ Branch 14 not taken.
76 LOG(INFO) << "EmbRead: keys.data=" << keys.data_as<uint64_t>()
295
3/6
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 38 times.
✗ Branch 8 not taken.
38 << ", values.data=" << values.data_as<float>();
296
2/4
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 38 times.
✗ Branch 4 not taken.
38 if (keys.shape(0) > 0) {
297
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 std::ostringstream oss;
298
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 oss << "EmbRead: keys start with: ";
299
3/4
✓ Branch 1 taken 142 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 104 times.
✓ Branch 5 taken 38 times.
142 for (int i = 0; i < std::min((int64_t)10, keys.shape(0)); ++i)
300
3/6
✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 104 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 104 times.
✗ Branch 8 not taken.
104 oss << keys.data_as<uint64_t>()[i] << ", ";
301
4/8
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 38 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 38 times.
✗ Branch 11 not taken.
38 LOG(INFO) << oss.str();
302 38 }
303
2/4
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 38 times.
✗ Branch 4 not taken.
38 if (values.shape(0) > 0) {
304
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 std::ostringstream oss;
305
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 oss << "EmbRead: values start with: ";
306
3/4
✓ Branch 1 taken 140 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 102 times.
✓ Branch 5 taken 38 times.
140 for (int i = 0; i < std::min((int64_t)10, values.shape(0)); ++i) {
307
1/2
✓ Branch 1 taken 102 times.
✗ Branch 2 not taken.
102 oss << "[";
308
3/4
✓ Branch 1 taken 1052 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 950 times.
✓ Branch 5 taken 102 times.
1052 for (int j = 0; j < std::min((int64_t)10, values.shape(1)); ++j) {
309
4/8
✓ Branch 1 taken 950 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 950 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 950 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 950 times.
✗ Branch 11 not taken.
950 oss << values.data_as<float>()[i * values.shape(1) + j] << ", ";
310 }
311
1/2
✓ Branch 1 taken 102 times.
✗ Branch 2 not taken.
102 oss << "] ";
312 }
313
4/8
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 38 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 38 times.
✗ Branch 11 not taken.
38 LOG(INFO) << oss.str();
314 38 }
315
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 validate_keys(keys);
316
2/4
✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 38 times.
✗ Branch 6 not taken.
38 validate_embeddings(values, "Values");
317
318
1/2
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
38 const int64_t L = keys.shape(0);
319
3/4
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✓ Branch 4 taken 36 times.
38 if (values.shape(0) != L) {
320
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 throw std::invalid_argument(
321
3/6
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
4 "Dimension mismatch: Keys has length " + std::to_string(L) +
322
3/6
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
8 " but values has length " + std::to_string(values.shape(0)));
323 }
324
1/2
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
36 const uint64_t* keys_data = keys.data_as<uint64_t>();
325 36 base::ConstArray<uint64_t> keys_array(keys_data, L);
326
1/2
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
36 float* values_data = values.data_as<float>();
327
328
1/2
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
36 const int64_t D = values.shape(1);
329 36 const size_t total = static_cast<size_t>(L) * static_cast<size_t>(D);
330
1/2
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
36 std::fill_n(values_data, total, 0.0f);
331
332
1/2
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
36 int ret = ps_client_->GetParameter(keys_array, values_data);
333
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 34 times.
36 if (!IsReadWriteSuccess(ps_client_, ret)) {
334
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Failed to read embeddings from PS client.");
335 }
336
337 # ifdef ENABLE_PERF_REPORT
338 auto end_time = std::chrono::high_resolution_clock::now();
339 auto duration =
340 std::chrono::duration_cast<std::chrono::microseconds>(
341 end_time - start_time)
342 .count();
343 std::string op_latency_key =
344 "EmbRead|" + std::to_string(static_cast<uint64_t>(start_us));
345 report("op_latency",
346 op_latency_key.c_str(),
347 "recstore_us",
348 static_cast<double>(duration));
349
350 report("embread_stages",
351 report_id.c_str(),
352 "duration_us",
353 static_cast<double>(duration));
354
355 report("embread_stages",
356 report_id.c_str(),
357 "request_size",
358 static_cast<double>(keys.shape(0)));
359
360 FlameGraphData op_data = {
361 "op::EmbRead",
362 start_us,
363 0, // level
364 static_cast<double>(duration),
365 static_cast<double>(duration)};
366 report_flame_graph("emb_read_flame_map", unique_id.c_str(), op_data);
367 # endif
368 }
369
370 4 void KVClientOp::EmbUpdate(const base::RecTensor& keys,
371 const base::RecTensor& grads) {
372
2/4
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
8 EmbUpdate("default", keys, grads);
373 }
374
375 10 void KVClientOp::EmbUpdate(const std::string& table_name,
376 const base::RecTensor& keys,
377 const base::RecTensor& grads) {
378
3/4
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✓ Branch 4 taken 6 times.
10 if (IsHierKVBackendName(ps_backend_name_)) {
379
3/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✓ Branch 5 taken 2 times.
4 GetHierKVLocalRuntime().Update(table_name, keys, grads);
380 2 return;
381 }
382
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
6 if (ps_client_ == nullptr) {
383 throw std::runtime_error("PS client is not initialized. Please call "
384 "KVClientOp::SetPSClient() first.");
385 }
386
387 # ifdef ENABLE_PERF_REPORT
388 auto start_time = std::chrono::high_resolution_clock::now();
389 const uint64_t trace_id = static_cast<uint64_t>(
390 std::chrono::duration_cast<std::chrono::microseconds>(
391 start_time.time_since_epoch())
392 .count());
393 struct TraceGuard {
394 explicit TraceGuard(uint64_t new_trace_id)
395 : previous_trace_id_(recstore::g_trace_id) {
396 recstore::g_trace_id = new_trace_id;
397 }
398 ~TraceGuard() { recstore::g_trace_id = previous_trace_id_; }
399 uint64_t previous_trace_id_;
400 } trace_guard(trace_id);
401 # endif
402
403 6 int64_t validate_done_us = 0;
404
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 validate_keys(keys);
405
2/4
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 6 times.
✗ Branch 6 not taken.
6 validate_embeddings(grads, "Grads");
406
407
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 const int64_t L = keys.shape(0);
408
2/4
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 if (grads.shape(0) != L) {
409 throw std::invalid_argument(
410 "Dimension mismatch: Keys has length " + std::to_string(L) +
411 " but grads has length " + std::to_string(grads.shape(0)));
412 }
413
414
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 const int64_t D = grads.shape(1);
415
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
6 if (D <= 0) {
416 throw std::invalid_argument(
417 "Invalid grad dimension D: " + std::to_string(D));
418 }
419
420 # ifdef ENABLE_PERF_REPORT
421 auto validate_done_time = std::chrono::high_resolution_clock::now();
422 validate_done_us =
423 std::chrono::duration_cast<std::chrono::microseconds>(
424 validate_done_time - start_time)
425 .count();
426 # endif
427
428
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 const uint64_t* keys_data = keys.data_as<uint64_t>();
429 6 base::ConstArray<uint64_t> keys_array(keys_data, L);
430
431
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 const float* grads_data = grads.data_as<float>();
432 int ret =
433
1/2
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
6 ps_client_->UpdateParameterFlat(table_name, keys_array, grads_data, L, D);
434
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 4 times.
6 if (ret != 0) {
435
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Failed to update embeddings via PS client.");
436 }
437
438 # ifdef ENABLE_PERF_REPORT
439 auto end_time = std::chrono::high_resolution_clock::now();
440 auto duration =
441 std::chrono::duration_cast<std::chrono::microseconds>(
442 end_time - start_time)
443 .count();
444 double start_us =
445 std::chrono::duration_cast<std::chrono::microseconds>(
446 start_time.time_since_epoch())
447 .count();
448 std::string op_latency_key =
449 "EmbUpdate|" + std::to_string(static_cast<uint64_t>(start_us));
450 report("op_latency",
451 op_latency_key.c_str(),
452 "recstore_us",
453 static_cast<double>(duration));
454
455 std::string update_stage_id =
456 "op_client::EmbUpdate|" + std::to_string(trace_id);
457 report("embupdate_stages",
458 update_stage_id.c_str(),
459 "op_total_us",
460 static_cast<double>(duration));
461 report("embupdate_stages",
462 update_stage_id.c_str(),
463 "op_validate_us",
464 static_cast<double>(validate_done_us));
465 report("embupdate_stages",
466 update_stage_id.c_str(),
467 "request_size",
468 static_cast<double>(L));
469 report("embupdate_stages",
470 update_stage_id.c_str(),
471 "embedding_dim",
472 static_cast<double>(D));
473 # endif
474 }
475
476 18 bool KVClientOp::InitEmbeddingTable(const std::string& table_name,
477 const EmbeddingTableConfig& config) {
478
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 14 times.
18 if (IsHierKVBackendName(ps_backend_name_)) {
479 4 return GetHierKVLocalRuntime().InitEmbeddingTable(table_name, config);
480 }
481
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
14 if (ps_client_ == nullptr) {
482 throw std::runtime_error("PS client is not initialized. Please call "
483 "KVClientOp::SetPSClient() first.");
484 }
485
486 # ifdef ENABLE_PERF_REPORT
487 auto start_time = std::chrono::high_resolution_clock::now();
488 # endif
489 14 int ret = ps_client_->InitEmbeddingTable(table_name, config);
490 # ifdef ENABLE_PERF_REPORT
491 auto end_time = std::chrono::high_resolution_clock::now();
492 auto duration =
493 std::chrono::duration_cast<std::chrono::microseconds>(
494 end_time - start_time)
495 .count();
496 double start_us =
497 std::chrono::duration_cast<std::chrono::microseconds>(
498 start_time.time_since_epoch())
499 .count();
500 std::string op_latency_key =
501 "InitEmbeddingTable|" + std::to_string(static_cast<uint64_t>(start_us));
502 // report(table_name, key, metric_name, value)
503 report("op_latency",
504 op_latency_key.c_str(),
505 "recstore_us",
506 static_cast<double>(duration));
507 # endif
508 14 return ret == 0;
509 }
510
511 36 void KVClientOp::EmbWrite(const RecTensor& keys, const RecTensor& values) {
512
3/4
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 14 times.
✓ Branch 4 taken 22 times.
36 if (IsHierKVBackendName(ps_backend_name_)) {
513
3/4
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✓ Branch 5 taken 8 times.
14 GetHierKVLocalRuntime().Write(keys, values);
514 6 return;
515 }
516
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 22 times.
22 if (ps_client_ == nullptr) {
517 throw std::runtime_error("PS client is not initialized. Please call "
518 "KVClientOp::SetPSClient() first.");
519 }
520
521 # ifdef ENABLE_PERF_REPORT
522 auto start_time = std::chrono::high_resolution_clock::now();
523 # endif
524
525
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 validate_keys(keys);
526
2/4
✓ Branch 2 taken 22 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 22 times.
✗ Branch 6 not taken.
22 validate_embeddings(values, "Values");
527
528
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 const int64_t L = keys.shape(0);
529 22 const auto& values_shape = values.shape();
530
2/4
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 22 times.
22 if (values.shape(0) != L) {
531 throw std::invalid_argument(
532 "Dimension mismatch: Keys has length " + std::to_string(L) +
533 " but values has length " + std::to_string(values.shape(0)));
534 }
535
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 const int64_t D = values.shape(1);
536
537
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 const uint64_t* keys_data = keys.data_as<uint64_t>();
538 22 base::ConstArray<uint64_t> keys_array(keys_data, L);
539
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 const float* values_data = values.data_as<float>();
540
541 22 const int64_t total_values = L * D;
542
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 22 times.
22 if (values_shape[0] * values_shape[1] != total_values) {
543 throw std::invalid_argument(
544 "Values total elements mismatch: expected " +
545 std::to_string(total_values) + ", but got " +
546 std::to_string(values_shape[0] * values_shape[1]));
547 }
548
549
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 22 times.
22 if (D <= 0) {
550 throw std::invalid_argument(
551 "Invalid embedding dimension D: " + std::to_string(D));
552 }
553
554 22 std::vector<std::vector<float>> values_vector;
555
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 values_vector.reserve(L);
556
2/2
✓ Branch 0 taken 600 times.
✓ Branch 1 taken 22 times.
622 for (int64_t i = 0; i < L; ++i) {
557
1/2
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
600 std::vector<float> row(D);
558 600 std::memcpy(row.data(), values_data + i * D, D * sizeof(float));
559 600 asm volatile("" ::: "memory");
560 _mm_mfence();
561
1/2
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
600 values_vector.push_back(std::move(row));
562 600 }
563
564
3/6
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 22 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 22 times.
✗ Branch 8 not taken.
22 LOG(INFO) << "=== Keys Array Info ===";
565
4/8
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 22 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 22 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 22 times.
✗ Branch 11 not taken.
22 LOG(INFO) << "Keys size: " << L;
566
1/2
✓ Branch 0 taken 22 times.
✗ Branch 1 not taken.
22 if (L > 0) {
567
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 std::ostringstream keys_stream;
568
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 keys_stream << "First 3 keys: ";
569
2/2
✓ Branch 1 taken 62 times.
✓ Branch 2 taken 22 times.
84 for (int64_t i = 0; i < std::min(L, static_cast<int64_t>(3)); ++i) {
570
2/4
✓ Branch 2 taken 62 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 62 times.
✗ Branch 6 not taken.
62 keys_stream << keys_array[i] << " ";
571 }
572
4/8
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 22 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 22 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 22 times.
✗ Branch 11 not taken.
22 LOG(INFO) << keys_stream.str();
573 22 }
574
575
3/6
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 22 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 22 times.
✗ Branch 8 not taken.
22 LOG(INFO) << "=== Values Vector Info ===";
576
4/8
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 22 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 22 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 22 times.
✗ Branch 11 not taken.
22 LOG(INFO) << "Values total elements: " << total_values;
577
4/8
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 22 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 22 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 22 times.
✗ Branch 11 not taken.
22 LOG(INFO) << "Embedding dimension D: " << D;
578
2/4
✓ Branch 0 taken 22 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22 times.
✗ Branch 3 not taken.
22 if (L > 0 && D > 0) {
579
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 std::ostringstream values_stream;
580
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 values_stream << "First 3 embeddings (each first 3 items): ";
581
2/2
✓ Branch 1 taken 62 times.
✓ Branch 2 taken 22 times.
84 for (int64_t i = 0; i < std::min(L, static_cast<int64_t>(3)); ++i) {
582
1/2
✓ Branch 1 taken 62 times.
✗ Branch 2 not taken.
62 values_stream << "[";
583
2/2
✓ Branch 1 taken 186 times.
✓ Branch 2 taken 62 times.
248 for (int64_t j = 0; j < std::min(D, static_cast<int64_t>(3)); ++j) {
584
2/4
✓ Branch 3 taken 186 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 186 times.
✗ Branch 7 not taken.
186 values_stream << values_vector[i][j] << " ";
585 }
586
1/2
✓ Branch 1 taken 62 times.
✗ Branch 2 not taken.
62 values_stream << "] ";
587 }
588
4/8
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 22 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 22 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 22 times.
✗ Branch 11 not taken.
22 LOG(INFO) << values_stream.str();
589 22 }
590
591
1/2
✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
22 int ret = ps_client_->PutParameter(keys_array, values_vector);
592
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 20 times.
22 if (!IsReadWriteSuccess(ps_client_, ret)) {
593
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Failed to write embeddings to PS client.");
594 }
595
596 # ifdef ENABLE_PERF_REPORT
597 auto end_time = std::chrono::high_resolution_clock::now();
598 auto duration =
599 std::chrono::duration_cast<std::chrono::microseconds>(
600 end_time - start_time)
601 .count();
602 double start_us =
603 std::chrono::duration_cast<std::chrono::microseconds>(
604 start_time.time_since_epoch())
605 .count();
606 std::string op_latency_key =
607 "EmbWrite|" + std::to_string(static_cast<uint64_t>(start_us));
608 report("op_latency",
609 op_latency_key.c_str(),
610 "recstore_us",
611 static_cast<double>(duration));
612 # endif
613 22 }
614
615 4 void KVClientOp::EmbInit(const base::RecTensor& keys,
616 const InitStrategy& strategy) {
617 4 validate_keys(keys);
618 2 }
619
620 uint64_t
621 8 KVClientOp::EmbPrefetch(const base::RecTensor& keys, const RecTensor& values) {
622
3/4
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✓ Branch 4 taken 4 times.
8 if (IsHierKVBackendName(ps_backend_name_)) {
623
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
4 int64_t embedding_dim = values.dim() == 2 ? values.shape(1) : -1;
624
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
4 if (embedding_dim <= 0) {
625 embedding_dim = GetHierKVLocalRuntime().DefaultEmbeddingDim();
626 }
627
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
4 return GetHierKVLocalRuntime().Prefetch(keys, embedding_dim);
628 }
629
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 const uint64_t* keys_data = keys.data_as<uint64_t>();
630
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 int64_t L = keys.shape(0);
631 4 base::ConstArray<uint64_t> keys_array(keys_data, L);
632
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 return ps_client_->PrefetchParameter(keys_array);
633 }
634
635 6 bool KVClientOp::IsPrefetchDone(uint64_t prefetch_id) {
636
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 2 times.
6 if (IsHierKVBackendName(ps_backend_name_)) {
637 4 return GetHierKVLocalRuntime().IsPrefetchDone(prefetch_id);
638 }
639 2 return ps_client_->IsPrefetchDone(prefetch_id);
640 }
641
642 6 void KVClientOp::WaitForPrefetch(uint64_t prefetch_id) {
643
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 4 times.
6 if (IsHierKVBackendName(ps_backend_name_)) {
644 2 GetHierKVLocalRuntime().WaitForPrefetch(prefetch_id);
645 2 return;
646 }
647 4 ps_client_->WaitForPrefetch(prefetch_id);
648 }
649
650 4 void KVClientOp::GetPretchResult(uint64_t prefetch_id,
651 std::vector<std::vector<float>>* values) {
652
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
4 if (IsHierKVBackendName(ps_backend_name_)) {
653 2 GetHierKVLocalRuntime().ConsumePrefetch(prefetch_id, values);
654 2 return;
655 }
656 2 ps_client_->GetPrefetchResult(prefetch_id, values);
657 }
658
659 6 void KVClientOp::GetPretchResultFlat(
660 uint64_t prefetch_id,
661 std::vector<float>* values,
662 int64_t* num_rows,
663 int64_t embedding_dim) {
664
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 4 times.
6 if (IsHierKVBackendName(ps_backend_name_)) {
665 2 GetHierKVLocalRuntime().ConsumePrefetchFlat(
666 prefetch_id, values, num_rows, embedding_dim);
667 2 return;
668 }
669 4 ps_client_->GetPrefetchResultFlat(
670 prefetch_id, values, num_rows, embedding_dim);
671 }
672
673 2 bool KVClientOp::IsWriteDone(uint64_t write_id) {
674 // return ps_client_->IsWriteDone(write_id);
675
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 throw std::runtime_error("Not impl");
676 }
677
678 namespace testing {} // namespace testing
679
680 } // namespace recstore
681
682 #else
683
684 # include "common/op_mock.cc"
685
686 #endif // USE_FAKE_KVCLIENT
687