GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 69.5% 207 / 0 / 298
Functions: 56.2% 9 / 0 / 16
Branches: 39.4% 182 / 0 / 462

storage/kv_engine/engine_composite.h
Line Branch Exec Source
1 #pragma once
2
3 #include <algorithm>
4 #include <chrono>
5 #include <cstring>
6 #include <memory>
7 #include <stdexcept>
8 #include <unordered_set>
9 #include <vector>
10
11 #include "base/factory.h"
12 #include "storage/index/dram/extendible_hash_index.h"
13 #include "storage/index/dram/pet_hash_index.h"
14 #include "storage/index/dram/unordered_map_index.h"
15 #include "storage/kv_engine/base_kv.h"
16 #include "storage/value_store/dram_value_store.h"
17 #include "storage/value_store/hybrid_value_store.h"
18 #include "storage/value_store/ssd_value_store.h"
19
20 class KVEngineComposite : public BaseKV {
21 public:
22 KVEngineComposite(std::unique_ptr<Index> index,
23 std::unique_ptr<ValueStore> value_store,
24 int num_threads = 0)
25 : BaseKV(BaseKVConfig{}),
26 index_(std::move(index)),
27 value_store_(std::move(value_store)),
28 num_threads_(num_threads) {}
29
30 1032 explicit KVEngineComposite(const BaseKVConfig& config) : BaseKV(config) {
31 1032 config_ = config;
32 1032 const auto& j = config.json_config_;
33
3/6
✓ Branch 1 taken 1032 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 1032 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 1032 times.
✗ Branch 8 not taken.
1032 const std::string index_type = j.at("index").at("type").get<std::string>();
34
3/6
✓ Branch 1 taken 1032 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 1032 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 1032 times.
✗ Branch 8 not taken.
1032 const std::string value_type = j.at("value").at("type").get<std::string>();
35 using IF = base::Factory<Index, const BaseKVConfig&>;
36 using VF = base::Factory<ValueStore, const BaseKVConfig&>;
37
1/2
✓ Branch 1 taken 1032 times.
✗ Branch 2 not taken.
1032 index_.reset(IF::NewInstance(index_type, config));
38
2/2
✓ Branch 1 taken 1030 times.
✓ Branch 2 taken 2 times.
1032 value_store_.reset(VF::NewInstance(value_type, config));
39 1030 num_threads_ = config.num_threads_;
40 1030 default_value_size_hint_ =
41
2/4
✓ Branch 1 taken 1030 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 1030 times.
✗ Branch 5 not taken.
1030 j.at("value").value("default_value_size_hint", 0);
42
3/6
✓ Branch 1 taken 1030 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 1030 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1030 times.
1030 if (!index_ || !value_store_) {
43 throw std::runtime_error("failed to create KVEngine components");
44 }
45 1042 }
46
47 201374 void Get(const uint64_t key, std::string& value, unsigned tid) override {
48 (void)tid;
49 201374 Value_t handle = kValueHandleNone;
50
1/2
✓ Branch 2 taken 201374 times.
✗ Branch 3 not taken.
201374 index_->Get(key, handle);
51
2/2
✓ Branch 0 taken 6280 times.
✓ Branch 1 taken 195094 times.
201374 if (handle == kValueHandleNone) {
52 6280 value.clear();
53 149054 return;
54 }
55
3/4
✓ Branch 2 taken 195094 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 142774 times.
✓ Branch 5 taken 52320 times.
195094 if (const char* ptr = value_store_->DirectPtr(handle)) {
56
2/4
✓ Branch 2 taken 142774 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 142774 times.
✗ Branch 6 not taken.
142774 value.resize(value_store_->SlotCapacity(handle));
57 142774 std::memcpy(value.data(), ptr, value.size());
58 142774 return;
59 }
60
2/4
✓ Branch 2 taken 52320 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 52320 times.
✗ Branch 6 not taken.
52320 value.resize(value_store_->SlotCapacity(handle));
61 const size_t actual =
62
1/2
✓ Branch 4 taken 52320 times.
✗ Branch 5 not taken.
52320 value_store_->Read(handle, value.data(), value.size());
63
1/2
✓ Branch 1 taken 52320 times.
✗ Branch 2 not taken.
52320 value.resize(actual);
64 }
65
66 bool Exists(const uint64_t key, unsigned tid) override {
67 (void)tid;
68 Value_t handle = kValueHandleNone;
69 index_->Get(key, handle);
70 return handle != kValueHandleNone;
71 }
72
73 354012 void Put(const uint64_t key,
74 const std::string_view& value,
75 unsigned tid) override {
76 354012 PutInternal(key, value.data(), value.size(), tid, true);
77 354012 }
78
79 292 void BatchPut(base::ConstArray<uint64_t> keys,
80 std::vector<base::ConstArray<float>>* values,
81 unsigned tid) override {
82
3/6
✓ Branch 0 taken 292 times.
✗ Branch 1 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 292 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 292 times.
292 if (values == nullptr || keys.Size() != static_cast<int>(values->size())) {
83 LOG(FATAL) << "KVEngine::BatchPut size mismatch";
84 }
85 (void)tid;
86
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 292 times.
292 if (keys.Size() == 0) {
87 return;
88 }
89
90 292 std::unordered_set<uint64_t> seen_keys;
91
1/2
✓ Branch 2 taken 292 times.
✗ Branch 3 not taken.
292 seen_keys.reserve(static_cast<size_t>(keys.Size()));
92 292 bool has_duplicate_key = false;
93
2/2
✓ Branch 1 taken 740 times.
✓ Branch 2 taken 292 times.
1032 for (int i = 0; i < keys.Size(); ++i) {
94
2/4
✓ Branch 2 taken 740 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 740 times.
740 if (!seen_keys.insert(keys[i]).second) {
95 has_duplicate_key = true;
96 break;
97 }
98 }
99
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 292 times.
292 if (has_duplicate_key) {
100 for (int i = 0; i < keys.Size(); ++i) {
101 const auto& item = (*values)[i];
102 PutInternal(keys[i],
103 item.Data(),
104 static_cast<size_t>(item.Size()) * sizeof(float),
105 tid,
106 false);
107 }
108 return;
109 }
110
111 struct PutItem {
112 uint64_t key = 0;
113 ValueStore::WriteSpec spec{};
114 };
115 292 std::vector<PutItem> items;
116
1/2
✓ Branch 2 taken 292 times.
✗ Branch 3 not taken.
292 items.reserve(static_cast<size_t>(keys.Size()));
117
118
2/2
✓ Branch 1 taken 740 times.
✓ Branch 2 taken 292 times.
1032 for (int i = 0; i < keys.Size(); ++i) {
119 740 const auto& item = (*values)[i];
120 740 const void* data = item.Data();
121 740 const size_t size = static_cast<size_t>(item.Size()) * sizeof(float);
122
1/2
✓ Branch 2 taken 740 times.
✗ Branch 3 not taken.
740 items.push_back(PutItem{keys[i], ValueStore::WriteSpec{data, size}});
123 }
124
125 292 std::vector<ValueStore::WriteSpec> specs;
126
1/2
✓ Branch 2 taken 292 times.
✗ Branch 3 not taken.
292 specs.reserve(items.size());
127
2/2
✓ Branch 5 taken 740 times.
✓ Branch 6 taken 292 times.
1032 for (const auto& item : items) {
128
1/2
✓ Branch 1 taken 740 times.
✗ Branch 2 not taken.
740 specs.push_back(item.spec);
129 }
130
1/2
✓ Branch 2 taken 292 times.
✗ Branch 3 not taken.
292 const auto new_handles = value_store_->BatchAllocAndWrite(specs);
131
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 292 times.
292 if (new_handles.size() != items.size()) {
132 LOG(FATAL) << "KVEngine::BatchPut allocation result size mismatch";
133 }
134
2/2
✓ Branch 1 taken 740 times.
✓ Branch 2 taken 292 times.
1032 for (size_t i = 0; i < items.size(); ++i) {
135
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 740 times.
740 if (new_handles[i] == kValueHandleNone) {
136 LOG(FATAL) << "KVEngine batch value allocation failed, key="
137 << items[i].key << " size=" << items[i].spec.size;
138 }
139 }
140
141
2/2
✓ Branch 1 taken 740 times.
✓ Branch 2 taken 292 times.
1032 for (size_t i = 0; i < items.size(); ++i) {
142
1/2
✓ Branch 4 taken 740 times.
✗ Branch 5 not taken.
740 Value_t old_handle = index_->Put(items[i].key, new_handles[i], tid);
143
2/2
✓ Branch 0 taken 216 times.
✓ Branch 1 taken 524 times.
740 if (old_handle != kValueHandleNone) {
144
1/2
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
216 value_store_->Retire(old_handle);
145 }
146 }
147
1/2
✓ Branch 4 taken 292 times.
✗ Branch 5 not taken.
292 }
148
149 1536 void BatchGet(base::ConstArray<uint64_t> keys,
150 std::vector<base::ConstArray<float>>* values,
151 unsigned tid) override {
152 (void)tid;
153
1/2
✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
1536 values->resize(keys.Size());
154
2/2
✓ Branch 0 taken 252 times.
✓ Branch 1 taken 1284 times.
1536 thread_local std::vector<Value_t> handles;
155
2/2
✓ Branch 0 taken 252 times.
✓ Branch 1 taken 1284 times.
1536 thread_local std::vector<std::vector<float>> buffers;
156
1/2
✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
1536 handles.assign(keys.Size(), kValueHandleNone);
157 1536 buffers.clear();
158
1/2
✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
1536 buffers.resize(keys.Size());
159
160
1/2
✓ Branch 1 taken 1536 times.
✗ Branch 2 not taken.
1536 if (keys.Size() > 0) {
161
1/2
✓ Branch 3 taken 1536 times.
✗ Branch 4 not taken.
1536 index_->BatchGet(keys, handles.data(), tid);
162 }
163 1536 std::vector<uint64_t> batch_handles;
164 1536 std::vector<size_t> batch_indices;
165
1/2
✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
1536 batch_handles.reserve(static_cast<size_t>(keys.Size()));
166
1/2
✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
1536 batch_indices.reserve(static_cast<size_t>(keys.Size()));
167
2/2
✓ Branch 1 taken 67426 times.
✓ Branch 2 taken 1536 times.
68962 for (int i = 0; i < keys.Size(); ++i) {
168
2/2
✓ Branch 1 taken 324 times.
✓ Branch 2 taken 67102 times.
67426 if (handles[i] == kValueHandleNone) {
169 324 (*values)[i] = base::ConstArray<float>();
170 324 continue;
171 }
172
3/4
✓ Branch 3 taken 67102 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 45542 times.
✓ Branch 6 taken 21560 times.
67102 if (const char* ptr = value_store_->DirectPtr(handles[i])) {
173
1/2
✓ Branch 3 taken 45542 times.
✗ Branch 4 not taken.
45542 const size_t bytes = value_store_->SlotCapacity(handles[i]);
174 45542 (*values)[i] = base::ConstArray<float>(
175 reinterpret_cast<float*>(const_cast<char*>(ptr)),
176 45542 bytes / sizeof(float));
177 45542 continue;
178 45542 }
179
1/2
✓ Branch 2 taken 21560 times.
✗ Branch 3 not taken.
21560 batch_handles.push_back(handles[i]);
180
1/2
✓ Branch 1 taken 21560 times.
✗ Branch 2 not taken.
21560 batch_indices.push_back(static_cast<size_t>(i));
181 }
182
183 1536 std::vector<ValueStore::ReadResult> batch_results;
184
2/2
✓ Branch 1 taken 410 times.
✓ Branch 2 taken 1126 times.
1536 if (!batch_handles.empty()) {
185
1/2
✓ Branch 2 taken 410 times.
✗ Branch 3 not taken.
410 value_store_->BatchRead(batch_handles, batch_results);
186
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 410 times.
410 if (batch_results.size() != batch_indices.size()) {
187 LOG(FATAL) << "KVEngine::BatchGet read result size mismatch";
188 }
189
2/2
✓ Branch 1 taken 21560 times.
✓ Branch 2 taken 410 times.
21970 for (size_t i = 0; i < batch_indices.size(); ++i) {
190 21560 const size_t idx = batch_indices[i];
191 21560 const auto& result = batch_results[i];
192
1/2
✓ Branch 3 taken 21560 times.
✗ Branch 4 not taken.
21560 buffers[idx].resize(result.data.size() / sizeof(float));
193
1/2
✓ Branch 1 taken 21560 times.
✗ Branch 2 not taken.
21560 if (!result.data.empty()) {
194 43120 std::memcpy(
195 21560 buffers[idx].data(), result.data.data(), result.data.size());
196 }
197 21560 (*values)[idx] =
198 43120 base::ConstArray<float>(buffers[idx].data(), buffers[idx].size());
199 }
200 }
201 1536 }
202
203 122 bool BatchGetFlat(base::ConstArray<uint64_t> keys,
204 float* values,
205 int64_t num_rows,
206 int64_t embedding_dim,
207 unsigned tid,
208 BatchGetFlatStats* stats = nullptr) override {
209
4/8
✓ Branch 0 taken 122 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 122 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 122 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 122 times.
244 if (values == nullptr || num_rows < 0 || embedding_dim <= 0 ||
210
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 122 times.
122 keys.Size() != static_cast<size_t>(num_rows)) {
211 return false;
212 }
213 122 const size_t row_bytes = static_cast<size_t>(embedding_dim) * sizeof(float);
214
2/2
✓ Branch 0 taken 16 times.
✓ Branch 1 taken 106 times.
122 thread_local std::vector<Value_t> handles;
215
2/2
✓ Branch 0 taken 16 times.
✓ Branch 1 taken 106 times.
122 thread_local std::vector<char> read_buffer;
216
1/2
✓ Branch 2 taken 122 times.
✗ Branch 3 not taken.
122 handles.assign(keys.Size(), kValueHandleNone);
217 const auto index_lookup_start =
218 54 stats != nullptr ? std::chrono::steady_clock::now()
219
2/2
✓ Branch 0 taken 54 times.
✓ Branch 1 taken 68 times.
122 : std::chrono::steady_clock::time_point{};
220
1/2
✓ Branch 1 taken 122 times.
✗ Branch 2 not taken.
122 if (keys.Size() > 0) {
221
1/2
✓ Branch 3 taken 122 times.
✗ Branch 4 not taken.
122 index_->BatchGet(keys, handles.data(), tid);
222 }
223
2/2
✓ Branch 0 taken 54 times.
✓ Branch 1 taken 68 times.
122 if (stats != nullptr) {
224 54 stats->index_lookup_ns = static_cast<std::uint64_t>(
225
1/2
✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
54 std::chrono::duration_cast< std::chrono::nanoseconds>(
226
0/2
✗ Branch 2 not taken.
✗ Branch 3 not taken.
108 std::chrono::steady_clock::now() - index_lookup_start)
227 54 .count());
228 }
229
230 122 std::uint64_t missing_zero_fill_ns = 0;
231 122 std::uint64_t missing_rows = 0;
232 const auto row_copy_start =
233 54 stats != nullptr ? std::chrono::steady_clock::now()
234
2/2
✓ Branch 0 taken 54 times.
✓ Branch 1 taken 68 times.
122 : std::chrono::steady_clock::time_point{};
235
4/4
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 110 times.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 110 times.
134 if (default_value_size_hint_ == row_bytes &&
236
2/4
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
24 value_store_->ReadFlatFixedRows(
237 12 handles.data(),
238 static_cast<size_t>(num_rows),
239 values,
240 row_bytes,
241 &missing_rows)) {
242
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
12 if (stats != nullptr) {
243 stats->zero_fill_ns = 0;
244 stats->row_copy_ns = static_cast<std::uint64_t>(
245 std::chrono::duration_cast< std::chrono::nanoseconds>(
246 std::chrono::steady_clock::now() - row_copy_start)
247 .count());
248 stats->missing_rows = missing_rows;
249 }
250 12 return true;
251 }
252
2/2
✓ Branch 0 taken 218 times.
✓ Branch 1 taken 54 times.
272 for (int64_t row = 0; row < num_rows; ++row) {
253 218 const Value_t handle = handles[static_cast<size_t>(row)];
254 218 float* dst = values + row * embedding_dim;
255
2/2
✓ Branch 0 taken 54 times.
✓ Branch 1 taken 164 times.
218 if (handle == kValueHandleNone) {
256 const auto missing_zero_start =
257 54 stats != nullptr ? std::chrono::steady_clock::now()
258
1/2
✓ Branch 0 taken 54 times.
✗ Branch 1 not taken.
54 : std::chrono::steady_clock::time_point{};
259 54 std::memset(dst, 0, row_bytes);
260
1/2
✓ Branch 0 taken 54 times.
✗ Branch 1 not taken.
54 if (stats != nullptr) {
261 54 missing_zero_fill_ns += static_cast<std::uint64_t>(
262
1/2
✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
54 std::chrono::duration_cast< std::chrono::nanoseconds>(
263
0/2
✗ Branch 2 not taken.
✗ Branch 3 not taken.
108 std::chrono::steady_clock::now() - missing_zero_start)
264 54 .count());
265 }
266 54 ++missing_rows;
267 54 continue;
268 54 }
269
3/4
✓ Branch 2 taken 164 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 110 times.
✓ Branch 5 taken 54 times.
164 if (const char* ptr = value_store_->DirectPtr(handle)) {
270
1/2
✓ Branch 0 taken 110 times.
✗ Branch 1 not taken.
110 if (default_value_size_hint_ != row_bytes) {
271
1/2
✓ Branch 2 taken 110 times.
✗ Branch 3 not taken.
110 const size_t slot_bytes = value_store_->SlotCapacity(handle);
272
2/2
✓ Branch 0 taken 38 times.
✓ Branch 1 taken 72 times.
110 if (slot_bytes != row_bytes) {
273
4/8
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 38 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 38 times.
✗ Branch 11 not taken.
76 LOG(ERROR) << "KVEngine::BatchGetFlat row size mismatch row=" << row
274
2/4
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 5 taken 38 times.
✗ Branch 6 not taken.
38 << " key=" << keys[static_cast<int>(row)]
275
2/4
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
38 << " expected_bytes=" << row_bytes
276
2/4
✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
38 << " actual_bytes=" << slot_bytes;
277 38 return false;
278 }
279 }
280 72 std::memcpy(dst, ptr, row_bytes);
281 } else {
282
1/2
✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
54 read_buffer.resize(row_bytes + sizeof(float));
283 const size_t actual =
284
1/2
✓ Branch 4 taken 54 times.
✗ Branch 5 not taken.
54 value_store_->Read(handle, read_buffer.data(), read_buffer.size());
285
2/2
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 36 times.
54 if (actual != row_bytes) {
286
4/8
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 11 not taken.
36 LOG(ERROR) << "KVEngine::BatchGetFlat read size mismatch row=" << row
287
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
18 << " key=" << keys[static_cast<int>(row)]
288
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
18 << " expected_bytes=" << row_bytes
289
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
18 << " actual_bytes=" << actual;
290 18 return false;
291 }
292 36 std::memcpy(dst, read_buffer.data(), row_bytes);
293 }
294 }
295
1/2
✓ Branch 0 taken 54 times.
✗ Branch 1 not taken.
54 if (stats != nullptr) {
296 54 stats->zero_fill_ns = missing_zero_fill_ns;
297 54 stats->row_copy_ns = static_cast<std::uint64_t>(
298
1/2
✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
54 std::chrono::duration_cast< std::chrono::nanoseconds>(
299
0/2
✗ Branch 2 not taken.
✗ Branch 3 not taken.
54 std::chrono::steady_clock::now() - row_copy_start)
300 54 .count());
301 54 stats->missing_rows = missing_rows;
302 }
303 54 return true;
304 }
305
306 bool BatchGetIndexOnly(base::ConstArray<uint64_t> keys,
307 unsigned tid,
308 BatchGetFlatStats* stats = nullptr) override {
309 thread_local std::vector<Value_t> handles;
310 handles.assign(keys.Size(), kValueHandleNone);
311 if (keys.Size() > 0) {
312 index_->BatchGet(keys, handles.data(), tid);
313 }
314 if (stats != nullptr) {
315 std::uint64_t missing_rows = 0;
316 for (const Value_t handle : handles) {
317 if (handle == kValueHandleNone) {
318 ++missing_rows;
319 }
320 }
321 stats->missing_rows = missing_rows;
322 }
323 return true;
324 }
325
326 bool BatchGetDirectFixedRows(
327 base::ConstArray<uint64_t> keys,
328 int64_t num_rows,
329 int64_t embedding_dim,
330 unsigned tid,
331 std::vector<DirectFixedRow>* rows,
332 BatchGetFlatStats* stats = nullptr) override {
333 if (rows == nullptr || num_rows < 0 || embedding_dim <= 0 ||
334 keys.Size() != static_cast<size_t>(num_rows)) {
335 return false;
336 }
337 const size_t row_bytes = static_cast<size_t>(embedding_dim) * sizeof(float);
338 if (default_value_size_hint_ != row_bytes) {
339 return false;
340 }
341 thread_local std::vector<Value_t> handles;
342 handles.assign(keys.Size(), kValueHandleNone);
343 const auto index_lookup_start =
344 stats != nullptr ? std::chrono::steady_clock::now()
345 : std::chrono::steady_clock::time_point{};
346 if (keys.Size() > 0) {
347 index_->BatchGet(keys, handles.data(), tid);
348 }
349 if (stats != nullptr) {
350 stats->index_lookup_ns = static_cast<std::uint64_t>(
351 std::chrono::duration_cast< std::chrono::nanoseconds>(
352 std::chrono::steady_clock::now() - index_lookup_start)
353 .count());
354 }
355
356 thread_local std::vector<ValueStore::DirectFixedRow> store_rows;
357 store_rows.resize(static_cast<size_t>(num_rows));
358 uint64_t missing_rows = 0;
359 if (!value_store_->GetDirectFixedRows(
360 handles.data(),
361 static_cast<size_t>(num_rows),
362 row_bytes,
363 store_rows.data(),
364 &missing_rows)) {
365 return false;
366 }
367 rows->resize(store_rows.size());
368 for (size_t i = 0; i < store_rows.size(); ++i) {
369 (*rows)[i] = DirectFixedRow{
370 store_rows[i].data, store_rows[i].size, store_rows[i].missing};
371 }
372 if (stats != nullptr) {
373 stats->missing_rows = missing_rows;
374 }
375 return true;
376 }
377
378 RDMABackingRegion GetRDMABackingRegion() const override {
379 if (!value_store_) {
380 return {};
381 }
382 return RDMABackingRegion{
383 value_store_->RDMABackingData(), value_store_->RDMABackingSize()};
384 }
385
386 6 bool ApplySgdUpdateFlat(
387 base::ConstArray<uint64_t> keys,
388 const float* grads,
389 int64_t num_rows,
390 int64_t embedding_dim,
391 float learning_rate,
392 uint8_t tag,
393 unsigned tid) override {
394
4/8
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
6 if (grads == nullptr || keys.Size() != static_cast<size_t>(num_rows) ||
395 embedding_dim <= 0) {
396 return false;
397 }
398 6 const int tag_bits = static_cast<int>(sizeof(tag) * 8);
399 6 const int shift = static_cast<int>(sizeof(uint64_t) * 8) - tag_bits;
400 6 const uint64_t key_mask = ~0ULL >> tag_bits;
401 6 const size_t row_bytes = static_cast<size_t>(embedding_dim) * sizeof(float);
402
1/2
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 std::vector<float> row(embedding_dim);
403
2/2
✓ Branch 0 taken 10 times.
✓ Branch 1 taken 6 times.
16 for (int64_t r = 0; r < num_rows; ++r) {
404 10 const uint64_t key = (static_cast<uint64_t>(tag) << shift) |
405 10 (keys[static_cast<size_t>(r)] & key_mask);
406 10 std::string current;
407
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 Get(key, current, tid);
408
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 2 times.
10 if (current.size() == row_bytes) {
409 8 std::memcpy(row.data(), current.data(), row_bytes);
410 } else {
411
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 std::fill(row.begin(), row.end(), 0.0f);
412 }
413 10 const float* grad = grads + r * embedding_dim;
414
2/2
✓ Branch 0 taken 40 times.
✓ Branch 1 taken 10 times.
50 for (int64_t c = 0; c < embedding_dim; ++c) {
415 40 row[c] -= learning_rate * grad[c];
416 }
417
1/2
✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
10 PutInternal(key, row.data(), row_bytes, tid, false);
418 10 }
419 6 return true;
420 6 }
421
422 54 void BulkLoad(base::ConstArray<uint64_t> keys, const void* value) override {
423 54 const auto& j = config_.json_config_;
424
2/4
✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 54 times.
✗ Branch 5 not taken.
54 const size_t value_size = j.at("value").value("default_value_size_hint", 0);
425
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 54 times.
54 if (value_size == 0) {
426 LOG(FATAL) << "KVEngine::BulkLoad requires value_size hint";
427 }
428
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 54 times.
54 if (keys.Size() == 0) {
429 return;
430 }
431 54 const char* data = reinterpret_cast<const char*>(value);
432 54 std::vector<ValueStore::WriteSpec> specs;
433
1/2
✓ Branch 2 taken 54 times.
✗ Branch 3 not taken.
54 specs.reserve(static_cast<size_t>(keys.Size()));
434
2/2
✓ Branch 1 taken 432 times.
✓ Branch 2 taken 54 times.
486 for (int i = 0; i < keys.Size(); ++i) {
435
1/2
✓ Branch 1 taken 432 times.
✗ Branch 2 not taken.
432 specs.push_back(ValueStore::WriteSpec{data + i * value_size, value_size});
436 }
437
1/2
✓ Branch 2 taken 54 times.
✗ Branch 3 not taken.
54 std::vector<uint64_t> handles = value_store_->BatchAllocAndWrite(specs);
438
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 54 times.
54 if (handles.size() != static_cast<size_t>(keys.Size())) {
439 LOG(FATAL) << "KVEngine::BulkLoad allocation result size mismatch";
440 }
441
2/2
✓ Branch 1 taken 432 times.
✓ Branch 2 taken 54 times.
486 for (int i = 0; i < keys.Size(); ++i) {
442
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 432 times.
432 if (handles[static_cast<size_t>(i)] == kValueHandleNone) {
443 LOG(FATAL) << "KVEngine bulk value allocation failed, key=" << keys[i]
444 << " size=" << value_size;
445 }
446 }
447
1/2
✓ Branch 3 taken 54 times.
✗ Branch 4 not taken.
54 index_->BatchPut(keys, handles.data(), 0);
448 54 }
449
450 void Util() override {
451 LOG(INFO) << "KVEngine index utilization=" << index_->Utilization()
452 << " value=" << value_store_->GetInfo();
453 }
454
455 void DebugInfo() const override {
456 index_->DebugInfo();
457 LOG(INFO) << value_store_->GetInfo();
458 }
459
460 std::string ExtraResultFields() const override {
461 return value_store_ ? value_store_->ExtraResultFields() : "";
462 }
463
464 private:
465 354022 void PutInternal(uint64_t key,
466 const void* data,
467 size_t size,
468 unsigned tid,
469 bool emit_fence) {
470 (void)tid;
471 (void)emit_fence;
472 354022 Value_t new_handle = value_store_->AllocAndWrite(data, size);
473
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 354022 times.
354022 if (new_handle == kValueHandleNone) {
474 LOG(FATAL) << "KVEngine value allocation failed, key=" << key
475 << " size=" << size;
476 return;
477 }
478 354022 Value_t old_handle = index_->Put(key, new_handle, tid);
479
2/2
✓ Branch 0 taken 156442 times.
✓ Branch 1 taken 197580 times.
354022 if (old_handle != kValueHandleNone) {
480 156442 value_store_->Retire(old_handle);
481 }
482 }
483
484 BaseKVConfig config_;
485 std::unique_ptr<Index> index_;
486 std::unique_ptr<ValueStore> value_store_;
487 int num_threads_ = 0;
488 size_t default_value_size_hint_ = 0;
489 };
490
491 FACTORY_REGISTER(
492 BaseKV, KVEngineComposite, KVEngineComposite, const BaseKVConfig&);
493