storage/kv_engine/engine_composite.h
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include <algorithm> | ||
| 4 | #include <chrono> | ||
| 5 | #include <cstring> | ||
| 6 | #include <memory> | ||
| 7 | #include <stdexcept> | ||
| 8 | #include <unordered_set> | ||
| 9 | #include <vector> | ||
| 10 | |||
| 11 | #include "base/factory.h" | ||
| 12 | #include "storage/index/dram/extendible_hash_index.h" | ||
| 13 | #include "storage/index/dram/pet_hash_index.h" | ||
| 14 | #include "storage/index/dram/unordered_map_index.h" | ||
| 15 | #include "storage/kv_engine/base_kv.h" | ||
| 16 | #include "storage/value_store/dram_value_store.h" | ||
| 17 | #include "storage/value_store/hybrid_value_store.h" | ||
| 18 | #include "storage/value_store/ssd_value_store.h" | ||
| 19 | |||
| 20 | class KVEngineComposite : public BaseKV { | ||
| 21 | public: | ||
| 22 | KVEngineComposite(std::unique_ptr<Index> index, | ||
| 23 | std::unique_ptr<ValueStore> value_store, | ||
| 24 | int num_threads = 0) | ||
| 25 | : BaseKV(BaseKVConfig{}), | ||
| 26 | index_(std::move(index)), | ||
| 27 | value_store_(std::move(value_store)), | ||
| 28 | num_threads_(num_threads) {} | ||
| 29 | |||
| 30 | 1032 | explicit KVEngineComposite(const BaseKVConfig& config) : BaseKV(config) { | |
| 31 | 1032 | config_ = config; | |
| 32 | 1032 | const auto& j = config.json_config_; | |
| 33 |
3/6✓ Branch 1 taken 1032 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 1032 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 1032 times.
✗ Branch 8 not taken.
|
1032 | const std::string index_type = j.at("index").at("type").get<std::string>(); |
| 34 |
3/6✓ Branch 1 taken 1032 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 1032 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 1032 times.
✗ Branch 8 not taken.
|
1032 | const std::string value_type = j.at("value").at("type").get<std::string>(); |
| 35 | using IF = base::Factory<Index, const BaseKVConfig&>; | ||
| 36 | using VF = base::Factory<ValueStore, const BaseKVConfig&>; | ||
| 37 |
1/2✓ Branch 1 taken 1032 times.
✗ Branch 2 not taken.
|
1032 | index_.reset(IF::NewInstance(index_type, config)); |
| 38 |
2/2✓ Branch 1 taken 1030 times.
✓ Branch 2 taken 2 times.
|
1032 | value_store_.reset(VF::NewInstance(value_type, config)); |
| 39 | 1030 | num_threads_ = config.num_threads_; | |
| 40 | 1030 | default_value_size_hint_ = | |
| 41 |
2/4✓ Branch 1 taken 1030 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 1030 times.
✗ Branch 5 not taken.
|
1030 | j.at("value").value("default_value_size_hint", 0); |
| 42 |
3/6✓ Branch 1 taken 1030 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 1030 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1030 times.
|
1030 | if (!index_ || !value_store_) { |
| 43 | ✗ | throw std::runtime_error("failed to create KVEngine components"); | |
| 44 | } | ||
| 45 | 1042 | } | |
| 46 | |||
| 47 | 201374 | void Get(const uint64_t key, std::string& value, unsigned tid) override { | |
| 48 | (void)tid; | ||
| 49 | 201374 | Value_t handle = kValueHandleNone; | |
| 50 |
1/2✓ Branch 2 taken 201374 times.
✗ Branch 3 not taken.
|
201374 | index_->Get(key, handle); |
| 51 |
2/2✓ Branch 0 taken 6280 times.
✓ Branch 1 taken 195094 times.
|
201374 | if (handle == kValueHandleNone) { |
| 52 | 6280 | value.clear(); | |
| 53 | 149054 | return; | |
| 54 | } | ||
| 55 |
3/4✓ Branch 2 taken 195094 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 142774 times.
✓ Branch 5 taken 52320 times.
|
195094 | if (const char* ptr = value_store_->DirectPtr(handle)) { |
| 56 |
2/4✓ Branch 2 taken 142774 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 142774 times.
✗ Branch 6 not taken.
|
142774 | value.resize(value_store_->SlotCapacity(handle)); |
| 57 | 142774 | std::memcpy(value.data(), ptr, value.size()); | |
| 58 | 142774 | return; | |
| 59 | } | ||
| 60 |
2/4✓ Branch 2 taken 52320 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 52320 times.
✗ Branch 6 not taken.
|
52320 | value.resize(value_store_->SlotCapacity(handle)); |
| 61 | const size_t actual = | ||
| 62 |
1/2✓ Branch 4 taken 52320 times.
✗ Branch 5 not taken.
|
52320 | value_store_->Read(handle, value.data(), value.size()); |
| 63 |
1/2✓ Branch 1 taken 52320 times.
✗ Branch 2 not taken.
|
52320 | value.resize(actual); |
| 64 | } | ||
| 65 | |||
| 66 | ✗ | bool Exists(const uint64_t key, unsigned tid) override { | |
| 67 | (void)tid; | ||
| 68 | ✗ | Value_t handle = kValueHandleNone; | |
| 69 | ✗ | index_->Get(key, handle); | |
| 70 | ✗ | return handle != kValueHandleNone; | |
| 71 | } | ||
| 72 | |||
| 73 | 354012 | void Put(const uint64_t key, | |
| 74 | const std::string_view& value, | ||
| 75 | unsigned tid) override { | ||
| 76 | 354012 | PutInternal(key, value.data(), value.size(), tid, true); | |
| 77 | 354012 | } | |
| 78 | |||
| 79 | 292 | void BatchPut(base::ConstArray<uint64_t> keys, | |
| 80 | std::vector<base::ConstArray<float>>* values, | ||
| 81 | unsigned tid) override { | ||
| 82 |
3/6✓ Branch 0 taken 292 times.
✗ Branch 1 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 292 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 292 times.
|
292 | if (values == nullptr || keys.Size() != static_cast<int>(values->size())) { |
| 83 | ✗ | LOG(FATAL) << "KVEngine::BatchPut size mismatch"; | |
| 84 | } | ||
| 85 | (void)tid; | ||
| 86 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 292 times.
|
292 | if (keys.Size() == 0) { |
| 87 | ✗ | return; | |
| 88 | } | ||
| 89 | |||
| 90 | 292 | std::unordered_set<uint64_t> seen_keys; | |
| 91 |
1/2✓ Branch 2 taken 292 times.
✗ Branch 3 not taken.
|
292 | seen_keys.reserve(static_cast<size_t>(keys.Size())); |
| 92 | 292 | bool has_duplicate_key = false; | |
| 93 |
2/2✓ Branch 1 taken 740 times.
✓ Branch 2 taken 292 times.
|
1032 | for (int i = 0; i < keys.Size(); ++i) { |
| 94 |
2/4✓ Branch 2 taken 740 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 740 times.
|
740 | if (!seen_keys.insert(keys[i]).second) { |
| 95 | ✗ | has_duplicate_key = true; | |
| 96 | ✗ | break; | |
| 97 | } | ||
| 98 | } | ||
| 99 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 292 times.
|
292 | if (has_duplicate_key) { |
| 100 | ✗ | for (int i = 0; i < keys.Size(); ++i) { | |
| 101 | ✗ | const auto& item = (*values)[i]; | |
| 102 | ✗ | PutInternal(keys[i], | |
| 103 | ✗ | item.Data(), | |
| 104 | ✗ | static_cast<size_t>(item.Size()) * sizeof(float), | |
| 105 | tid, | ||
| 106 | false); | ||
| 107 | } | ||
| 108 | ✗ | return; | |
| 109 | } | ||
| 110 | |||
| 111 | struct PutItem { | ||
| 112 | uint64_t key = 0; | ||
| 113 | ValueStore::WriteSpec spec{}; | ||
| 114 | }; | ||
| 115 | 292 | std::vector<PutItem> items; | |
| 116 |
1/2✓ Branch 2 taken 292 times.
✗ Branch 3 not taken.
|
292 | items.reserve(static_cast<size_t>(keys.Size())); |
| 117 | |||
| 118 |
2/2✓ Branch 1 taken 740 times.
✓ Branch 2 taken 292 times.
|
1032 | for (int i = 0; i < keys.Size(); ++i) { |
| 119 | 740 | const auto& item = (*values)[i]; | |
| 120 | 740 | const void* data = item.Data(); | |
| 121 | 740 | const size_t size = static_cast<size_t>(item.Size()) * sizeof(float); | |
| 122 |
1/2✓ Branch 2 taken 740 times.
✗ Branch 3 not taken.
|
740 | items.push_back(PutItem{keys[i], ValueStore::WriteSpec{data, size}}); |
| 123 | } | ||
| 124 | |||
| 125 | 292 | std::vector<ValueStore::WriteSpec> specs; | |
| 126 |
1/2✓ Branch 2 taken 292 times.
✗ Branch 3 not taken.
|
292 | specs.reserve(items.size()); |
| 127 |
2/2✓ Branch 5 taken 740 times.
✓ Branch 6 taken 292 times.
|
1032 | for (const auto& item : items) { |
| 128 |
1/2✓ Branch 1 taken 740 times.
✗ Branch 2 not taken.
|
740 | specs.push_back(item.spec); |
| 129 | } | ||
| 130 |
1/2✓ Branch 2 taken 292 times.
✗ Branch 3 not taken.
|
292 | const auto new_handles = value_store_->BatchAllocAndWrite(specs); |
| 131 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 292 times.
|
292 | if (new_handles.size() != items.size()) { |
| 132 | ✗ | LOG(FATAL) << "KVEngine::BatchPut allocation result size mismatch"; | |
| 133 | } | ||
| 134 |
2/2✓ Branch 1 taken 740 times.
✓ Branch 2 taken 292 times.
|
1032 | for (size_t i = 0; i < items.size(); ++i) { |
| 135 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 740 times.
|
740 | if (new_handles[i] == kValueHandleNone) { |
| 136 | ✗ | LOG(FATAL) << "KVEngine batch value allocation failed, key=" | |
| 137 | ✗ | << items[i].key << " size=" << items[i].spec.size; | |
| 138 | } | ||
| 139 | } | ||
| 140 | |||
| 141 |
2/2✓ Branch 1 taken 740 times.
✓ Branch 2 taken 292 times.
|
1032 | for (size_t i = 0; i < items.size(); ++i) { |
| 142 |
1/2✓ Branch 4 taken 740 times.
✗ Branch 5 not taken.
|
740 | Value_t old_handle = index_->Put(items[i].key, new_handles[i], tid); |
| 143 |
2/2✓ Branch 0 taken 216 times.
✓ Branch 1 taken 524 times.
|
740 | if (old_handle != kValueHandleNone) { |
| 144 |
1/2✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
216 | value_store_->Retire(old_handle); |
| 145 | } | ||
| 146 | } | ||
| 147 |
1/2✓ Branch 4 taken 292 times.
✗ Branch 5 not taken.
|
292 | } |
| 148 | |||
| 149 | 1536 | void BatchGet(base::ConstArray<uint64_t> keys, | |
| 150 | std::vector<base::ConstArray<float>>* values, | ||
| 151 | unsigned tid) override { | ||
| 152 | (void)tid; | ||
| 153 |
1/2✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
|
1536 | values->resize(keys.Size()); |
| 154 |
2/2✓ Branch 0 taken 252 times.
✓ Branch 1 taken 1284 times.
|
1536 | thread_local std::vector<Value_t> handles; |
| 155 |
2/2✓ Branch 0 taken 252 times.
✓ Branch 1 taken 1284 times.
|
1536 | thread_local std::vector<std::vector<float>> buffers; |
| 156 |
1/2✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
|
1536 | handles.assign(keys.Size(), kValueHandleNone); |
| 157 | 1536 | buffers.clear(); | |
| 158 |
1/2✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
|
1536 | buffers.resize(keys.Size()); |
| 159 | |||
| 160 |
1/2✓ Branch 1 taken 1536 times.
✗ Branch 2 not taken.
|
1536 | if (keys.Size() > 0) { |
| 161 |
1/2✓ Branch 3 taken 1536 times.
✗ Branch 4 not taken.
|
1536 | index_->BatchGet(keys, handles.data(), tid); |
| 162 | } | ||
| 163 | 1536 | std::vector<uint64_t> batch_handles; | |
| 164 | 1536 | std::vector<size_t> batch_indices; | |
| 165 |
1/2✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
|
1536 | batch_handles.reserve(static_cast<size_t>(keys.Size())); |
| 166 |
1/2✓ Branch 2 taken 1536 times.
✗ Branch 3 not taken.
|
1536 | batch_indices.reserve(static_cast<size_t>(keys.Size())); |
| 167 |
2/2✓ Branch 1 taken 67426 times.
✓ Branch 2 taken 1536 times.
|
68962 | for (int i = 0; i < keys.Size(); ++i) { |
| 168 |
2/2✓ Branch 1 taken 324 times.
✓ Branch 2 taken 67102 times.
|
67426 | if (handles[i] == kValueHandleNone) { |
| 169 | 324 | (*values)[i] = base::ConstArray<float>(); | |
| 170 | 324 | continue; | |
| 171 | } | ||
| 172 |
3/4✓ Branch 3 taken 67102 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 45542 times.
✓ Branch 6 taken 21560 times.
|
67102 | if (const char* ptr = value_store_->DirectPtr(handles[i])) { |
| 173 |
1/2✓ Branch 3 taken 45542 times.
✗ Branch 4 not taken.
|
45542 | const size_t bytes = value_store_->SlotCapacity(handles[i]); |
| 174 | 45542 | (*values)[i] = base::ConstArray<float>( | |
| 175 | reinterpret_cast<float*>(const_cast<char*>(ptr)), | ||
| 176 | 45542 | bytes / sizeof(float)); | |
| 177 | 45542 | continue; | |
| 178 | 45542 | } | |
| 179 |
1/2✓ Branch 2 taken 21560 times.
✗ Branch 3 not taken.
|
21560 | batch_handles.push_back(handles[i]); |
| 180 |
1/2✓ Branch 1 taken 21560 times.
✗ Branch 2 not taken.
|
21560 | batch_indices.push_back(static_cast<size_t>(i)); |
| 181 | } | ||
| 182 | |||
| 183 | 1536 | std::vector<ValueStore::ReadResult> batch_results; | |
| 184 |
2/2✓ Branch 1 taken 410 times.
✓ Branch 2 taken 1126 times.
|
1536 | if (!batch_handles.empty()) { |
| 185 |
1/2✓ Branch 2 taken 410 times.
✗ Branch 3 not taken.
|
410 | value_store_->BatchRead(batch_handles, batch_results); |
| 186 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 410 times.
|
410 | if (batch_results.size() != batch_indices.size()) { |
| 187 | ✗ | LOG(FATAL) << "KVEngine::BatchGet read result size mismatch"; | |
| 188 | } | ||
| 189 |
2/2✓ Branch 1 taken 21560 times.
✓ Branch 2 taken 410 times.
|
21970 | for (size_t i = 0; i < batch_indices.size(); ++i) { |
| 190 | 21560 | const size_t idx = batch_indices[i]; | |
| 191 | 21560 | const auto& result = batch_results[i]; | |
| 192 |
1/2✓ Branch 3 taken 21560 times.
✗ Branch 4 not taken.
|
21560 | buffers[idx].resize(result.data.size() / sizeof(float)); |
| 193 |
1/2✓ Branch 1 taken 21560 times.
✗ Branch 2 not taken.
|
21560 | if (!result.data.empty()) { |
| 194 | 43120 | std::memcpy( | |
| 195 | 21560 | buffers[idx].data(), result.data.data(), result.data.size()); | |
| 196 | } | ||
| 197 | 21560 | (*values)[idx] = | |
| 198 | 43120 | base::ConstArray<float>(buffers[idx].data(), buffers[idx].size()); | |
| 199 | } | ||
| 200 | } | ||
| 201 | 1536 | } | |
| 202 | |||
| 203 | 122 | bool BatchGetFlat(base::ConstArray<uint64_t> keys, | |
| 204 | float* values, | ||
| 205 | int64_t num_rows, | ||
| 206 | int64_t embedding_dim, | ||
| 207 | unsigned tid, | ||
| 208 | BatchGetFlatStats* stats = nullptr) override { | ||
| 209 |
4/8✓ Branch 0 taken 122 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 122 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 122 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 122 times.
|
244 | if (values == nullptr || num_rows < 0 || embedding_dim <= 0 || |
| 210 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 122 times.
|
122 | keys.Size() != static_cast<size_t>(num_rows)) { |
| 211 | ✗ | return false; | |
| 212 | } | ||
| 213 | 122 | const size_t row_bytes = static_cast<size_t>(embedding_dim) * sizeof(float); | |
| 214 |
2/2✓ Branch 0 taken 16 times.
✓ Branch 1 taken 106 times.
|
122 | thread_local std::vector<Value_t> handles; |
| 215 |
2/2✓ Branch 0 taken 16 times.
✓ Branch 1 taken 106 times.
|
122 | thread_local std::vector<char> read_buffer; |
| 216 |
1/2✓ Branch 2 taken 122 times.
✗ Branch 3 not taken.
|
122 | handles.assign(keys.Size(), kValueHandleNone); |
| 217 | const auto index_lookup_start = | ||
| 218 | 54 | stats != nullptr ? std::chrono::steady_clock::now() | |
| 219 |
2/2✓ Branch 0 taken 54 times.
✓ Branch 1 taken 68 times.
|
122 | : std::chrono::steady_clock::time_point{}; |
| 220 |
1/2✓ Branch 1 taken 122 times.
✗ Branch 2 not taken.
|
122 | if (keys.Size() > 0) { |
| 221 |
1/2✓ Branch 3 taken 122 times.
✗ Branch 4 not taken.
|
122 | index_->BatchGet(keys, handles.data(), tid); |
| 222 | } | ||
| 223 |
2/2✓ Branch 0 taken 54 times.
✓ Branch 1 taken 68 times.
|
122 | if (stats != nullptr) { |
| 224 | 54 | stats->index_lookup_ns = static_cast<std::uint64_t>( | |
| 225 |
1/2✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
|
54 | std::chrono::duration_cast< std::chrono::nanoseconds>( |
| 226 |
0/2✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
108 | std::chrono::steady_clock::now() - index_lookup_start) |
| 227 | 54 | .count()); | |
| 228 | } | ||
| 229 | |||
| 230 | 122 | std::uint64_t missing_zero_fill_ns = 0; | |
| 231 | 122 | std::uint64_t missing_rows = 0; | |
| 232 | const auto row_copy_start = | ||
| 233 | 54 | stats != nullptr ? std::chrono::steady_clock::now() | |
| 234 |
2/2✓ Branch 0 taken 54 times.
✓ Branch 1 taken 68 times.
|
122 | : std::chrono::steady_clock::time_point{}; |
| 235 |
4/4✓ Branch 0 taken 12 times.
✓ Branch 1 taken 110 times.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 110 times.
|
134 | if (default_value_size_hint_ == row_bytes && |
| 236 |
2/4✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
|
24 | value_store_->ReadFlatFixedRows( |
| 237 | 12 | handles.data(), | |
| 238 | static_cast<size_t>(num_rows), | ||
| 239 | values, | ||
| 240 | row_bytes, | ||
| 241 | &missing_rows)) { | ||
| 242 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if (stats != nullptr) { |
| 243 | ✗ | stats->zero_fill_ns = 0; | |
| 244 | ✗ | stats->row_copy_ns = static_cast<std::uint64_t>( | |
| 245 | ✗ | std::chrono::duration_cast< std::chrono::nanoseconds>( | |
| 246 | ✗ | std::chrono::steady_clock::now() - row_copy_start) | |
| 247 | ✗ | .count()); | |
| 248 | ✗ | stats->missing_rows = missing_rows; | |
| 249 | } | ||
| 250 | 12 | return true; | |
| 251 | } | ||
| 252 |
2/2✓ Branch 0 taken 218 times.
✓ Branch 1 taken 54 times.
|
272 | for (int64_t row = 0; row < num_rows; ++row) { |
| 253 | 218 | const Value_t handle = handles[static_cast<size_t>(row)]; | |
| 254 | 218 | float* dst = values + row * embedding_dim; | |
| 255 |
2/2✓ Branch 0 taken 54 times.
✓ Branch 1 taken 164 times.
|
218 | if (handle == kValueHandleNone) { |
| 256 | const auto missing_zero_start = | ||
| 257 | 54 | stats != nullptr ? std::chrono::steady_clock::now() | |
| 258 |
1/2✓ Branch 0 taken 54 times.
✗ Branch 1 not taken.
|
54 | : std::chrono::steady_clock::time_point{}; |
| 259 | 54 | std::memset(dst, 0, row_bytes); | |
| 260 |
1/2✓ Branch 0 taken 54 times.
✗ Branch 1 not taken.
|
54 | if (stats != nullptr) { |
| 261 | 54 | missing_zero_fill_ns += static_cast<std::uint64_t>( | |
| 262 |
1/2✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
|
54 | std::chrono::duration_cast< std::chrono::nanoseconds>( |
| 263 |
0/2✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
108 | std::chrono::steady_clock::now() - missing_zero_start) |
| 264 | 54 | .count()); | |
| 265 | } | ||
| 266 | 54 | ++missing_rows; | |
| 267 | 54 | continue; | |
| 268 | 54 | } | |
| 269 |
3/4✓ Branch 2 taken 164 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 110 times.
✓ Branch 5 taken 54 times.
|
164 | if (const char* ptr = value_store_->DirectPtr(handle)) { |
| 270 |
1/2✓ Branch 0 taken 110 times.
✗ Branch 1 not taken.
|
110 | if (default_value_size_hint_ != row_bytes) { |
| 271 |
1/2✓ Branch 2 taken 110 times.
✗ Branch 3 not taken.
|
110 | const size_t slot_bytes = value_store_->SlotCapacity(handle); |
| 272 |
2/2✓ Branch 0 taken 38 times.
✓ Branch 1 taken 72 times.
|
110 | if (slot_bytes != row_bytes) { |
| 273 |
4/8✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 38 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 38 times.
✗ Branch 11 not taken.
|
76 | LOG(ERROR) << "KVEngine::BatchGetFlat row size mismatch row=" << row |
| 274 |
2/4✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 5 taken 38 times.
✗ Branch 6 not taken.
|
38 | << " key=" << keys[static_cast<int>(row)] |
| 275 |
2/4✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
|
38 | << " expected_bytes=" << row_bytes |
| 276 |
2/4✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
|
38 | << " actual_bytes=" << slot_bytes; |
| 277 | 38 | return false; | |
| 278 | } | ||
| 279 | } | ||
| 280 | 72 | std::memcpy(dst, ptr, row_bytes); | |
| 281 | } else { | ||
| 282 |
1/2✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
|
54 | read_buffer.resize(row_bytes + sizeof(float)); |
| 283 | const size_t actual = | ||
| 284 |
1/2✓ Branch 4 taken 54 times.
✗ Branch 5 not taken.
|
54 | value_store_->Read(handle, read_buffer.data(), read_buffer.size()); |
| 285 |
2/2✓ Branch 0 taken 18 times.
✓ Branch 1 taken 36 times.
|
54 | if (actual != row_bytes) { |
| 286 |
4/8✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 11 not taken.
|
36 | LOG(ERROR) << "KVEngine::BatchGetFlat read size mismatch row=" << row |
| 287 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
|
18 | << " key=" << keys[static_cast<int>(row)] |
| 288 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
18 | << " expected_bytes=" << row_bytes |
| 289 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
18 | << " actual_bytes=" << actual; |
| 290 | 18 | return false; | |
| 291 | } | ||
| 292 | 36 | std::memcpy(dst, read_buffer.data(), row_bytes); | |
| 293 | } | ||
| 294 | } | ||
| 295 |
1/2✓ Branch 0 taken 54 times.
✗ Branch 1 not taken.
|
54 | if (stats != nullptr) { |
| 296 | 54 | stats->zero_fill_ns = missing_zero_fill_ns; | |
| 297 | 54 | stats->row_copy_ns = static_cast<std::uint64_t>( | |
| 298 |
1/2✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
|
54 | std::chrono::duration_cast< std::chrono::nanoseconds>( |
| 299 |
0/2✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
54 | std::chrono::steady_clock::now() - row_copy_start) |
| 300 | 54 | .count()); | |
| 301 | 54 | stats->missing_rows = missing_rows; | |
| 302 | } | ||
| 303 | 54 | return true; | |
| 304 | } | ||
| 305 | |||
| 306 | ✗ | bool BatchGetIndexOnly(base::ConstArray<uint64_t> keys, | |
| 307 | unsigned tid, | ||
| 308 | BatchGetFlatStats* stats = nullptr) override { | ||
| 309 | ✗ | thread_local std::vector<Value_t> handles; | |
| 310 | ✗ | handles.assign(keys.Size(), kValueHandleNone); | |
| 311 | ✗ | if (keys.Size() > 0) { | |
| 312 | ✗ | index_->BatchGet(keys, handles.data(), tid); | |
| 313 | } | ||
| 314 | ✗ | if (stats != nullptr) { | |
| 315 | ✗ | std::uint64_t missing_rows = 0; | |
| 316 | ✗ | for (const Value_t handle : handles) { | |
| 317 | ✗ | if (handle == kValueHandleNone) { | |
| 318 | ✗ | ++missing_rows; | |
| 319 | } | ||
| 320 | } | ||
| 321 | ✗ | stats->missing_rows = missing_rows; | |
| 322 | } | ||
| 323 | ✗ | return true; | |
| 324 | } | ||
| 325 | |||
| 326 | ✗ | bool BatchGetDirectFixedRows( | |
| 327 | base::ConstArray<uint64_t> keys, | ||
| 328 | int64_t num_rows, | ||
| 329 | int64_t embedding_dim, | ||
| 330 | unsigned tid, | ||
| 331 | std::vector<DirectFixedRow>* rows, | ||
| 332 | BatchGetFlatStats* stats = nullptr) override { | ||
| 333 | ✗ | if (rows == nullptr || num_rows < 0 || embedding_dim <= 0 || | |
| 334 | ✗ | keys.Size() != static_cast<size_t>(num_rows)) { | |
| 335 | ✗ | return false; | |
| 336 | } | ||
| 337 | ✗ | const size_t row_bytes = static_cast<size_t>(embedding_dim) * sizeof(float); | |
| 338 | ✗ | if (default_value_size_hint_ != row_bytes) { | |
| 339 | ✗ | return false; | |
| 340 | } | ||
| 341 | ✗ | thread_local std::vector<Value_t> handles; | |
| 342 | ✗ | handles.assign(keys.Size(), kValueHandleNone); | |
| 343 | const auto index_lookup_start = | ||
| 344 | ✗ | stats != nullptr ? std::chrono::steady_clock::now() | |
| 345 | ✗ | : std::chrono::steady_clock::time_point{}; | |
| 346 | ✗ | if (keys.Size() > 0) { | |
| 347 | ✗ | index_->BatchGet(keys, handles.data(), tid); | |
| 348 | } | ||
| 349 | ✗ | if (stats != nullptr) { | |
| 350 | ✗ | stats->index_lookup_ns = static_cast<std::uint64_t>( | |
| 351 | ✗ | std::chrono::duration_cast< std::chrono::nanoseconds>( | |
| 352 | ✗ | std::chrono::steady_clock::now() - index_lookup_start) | |
| 353 | ✗ | .count()); | |
| 354 | } | ||
| 355 | |||
| 356 | ✗ | thread_local std::vector<ValueStore::DirectFixedRow> store_rows; | |
| 357 | ✗ | store_rows.resize(static_cast<size_t>(num_rows)); | |
| 358 | ✗ | uint64_t missing_rows = 0; | |
| 359 | ✗ | if (!value_store_->GetDirectFixedRows( | |
| 360 | ✗ | handles.data(), | |
| 361 | static_cast<size_t>(num_rows), | ||
| 362 | row_bytes, | ||
| 363 | store_rows.data(), | ||
| 364 | &missing_rows)) { | ||
| 365 | ✗ | return false; | |
| 366 | } | ||
| 367 | ✗ | rows->resize(store_rows.size()); | |
| 368 | ✗ | for (size_t i = 0; i < store_rows.size(); ++i) { | |
| 369 | ✗ | (*rows)[i] = DirectFixedRow{ | |
| 370 | ✗ | store_rows[i].data, store_rows[i].size, store_rows[i].missing}; | |
| 371 | } | ||
| 372 | ✗ | if (stats != nullptr) { | |
| 373 | ✗ | stats->missing_rows = missing_rows; | |
| 374 | } | ||
| 375 | ✗ | return true; | |
| 376 | } | ||
| 377 | |||
| 378 | ✗ | RDMABackingRegion GetRDMABackingRegion() const override { | |
| 379 | ✗ | if (!value_store_) { | |
| 380 | ✗ | return {}; | |
| 381 | } | ||
| 382 | return RDMABackingRegion{ | ||
| 383 | ✗ | value_store_->RDMABackingData(), value_store_->RDMABackingSize()}; | |
| 384 | } | ||
| 385 | |||
| 386 | 6 | bool ApplySgdUpdateFlat( | |
| 387 | base::ConstArray<uint64_t> keys, | ||
| 388 | const float* grads, | ||
| 389 | int64_t num_rows, | ||
| 390 | int64_t embedding_dim, | ||
| 391 | float learning_rate, | ||
| 392 | uint8_t tag, | ||
| 393 | unsigned tid) override { | ||
| 394 |
4/8✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
|
6 | if (grads == nullptr || keys.Size() != static_cast<size_t>(num_rows) || |
| 395 | embedding_dim <= 0) { | ||
| 396 | ✗ | return false; | |
| 397 | } | ||
| 398 | 6 | const int tag_bits = static_cast<int>(sizeof(tag) * 8); | |
| 399 | 6 | const int shift = static_cast<int>(sizeof(uint64_t) * 8) - tag_bits; | |
| 400 | 6 | const uint64_t key_mask = ~0ULL >> tag_bits; | |
| 401 | 6 | const size_t row_bytes = static_cast<size_t>(embedding_dim) * sizeof(float); | |
| 402 |
1/2✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | std::vector<float> row(embedding_dim); |
| 403 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 6 times.
|
16 | for (int64_t r = 0; r < num_rows; ++r) { |
| 404 | 10 | const uint64_t key = (static_cast<uint64_t>(tag) << shift) | | |
| 405 | 10 | (keys[static_cast<size_t>(r)] & key_mask); | |
| 406 | 10 | std::string current; | |
| 407 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | Get(key, current, tid); |
| 408 |
2/2✓ Branch 1 taken 8 times.
✓ Branch 2 taken 2 times.
|
10 | if (current.size() == row_bytes) { |
| 409 | 8 | std::memcpy(row.data(), current.data(), row_bytes); | |
| 410 | } else { | ||
| 411 |
1/2✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
|
2 | std::fill(row.begin(), row.end(), 0.0f); |
| 412 | } | ||
| 413 | 10 | const float* grad = grads + r * embedding_dim; | |
| 414 |
2/2✓ Branch 0 taken 40 times.
✓ Branch 1 taken 10 times.
|
50 | for (int64_t c = 0; c < embedding_dim; ++c) { |
| 415 | 40 | row[c] -= learning_rate * grad[c]; | |
| 416 | } | ||
| 417 |
1/2✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
|
10 | PutInternal(key, row.data(), row_bytes, tid, false); |
| 418 | 10 | } | |
| 419 | 6 | return true; | |
| 420 | 6 | } | |
| 421 | |||
| 422 | 54 | void BulkLoad(base::ConstArray<uint64_t> keys, const void* value) override { | |
| 423 | 54 | const auto& j = config_.json_config_; | |
| 424 |
2/4✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 54 times.
✗ Branch 5 not taken.
|
54 | const size_t value_size = j.at("value").value("default_value_size_hint", 0); |
| 425 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 54 times.
|
54 | if (value_size == 0) { |
| 426 | ✗ | LOG(FATAL) << "KVEngine::BulkLoad requires value_size hint"; | |
| 427 | } | ||
| 428 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 54 times.
|
54 | if (keys.Size() == 0) { |
| 429 | ✗ | return; | |
| 430 | } | ||
| 431 | 54 | const char* data = reinterpret_cast<const char*>(value); | |
| 432 | 54 | std::vector<ValueStore::WriteSpec> specs; | |
| 433 |
1/2✓ Branch 2 taken 54 times.
✗ Branch 3 not taken.
|
54 | specs.reserve(static_cast<size_t>(keys.Size())); |
| 434 |
2/2✓ Branch 1 taken 432 times.
✓ Branch 2 taken 54 times.
|
486 | for (int i = 0; i < keys.Size(); ++i) { |
| 435 |
1/2✓ Branch 1 taken 432 times.
✗ Branch 2 not taken.
|
432 | specs.push_back(ValueStore::WriteSpec{data + i * value_size, value_size}); |
| 436 | } | ||
| 437 |
1/2✓ Branch 2 taken 54 times.
✗ Branch 3 not taken.
|
54 | std::vector<uint64_t> handles = value_store_->BatchAllocAndWrite(specs); |
| 438 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 54 times.
|
54 | if (handles.size() != static_cast<size_t>(keys.Size())) { |
| 439 | ✗ | LOG(FATAL) << "KVEngine::BulkLoad allocation result size mismatch"; | |
| 440 | } | ||
| 441 |
2/2✓ Branch 1 taken 432 times.
✓ Branch 2 taken 54 times.
|
486 | for (int i = 0; i < keys.Size(); ++i) { |
| 442 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 432 times.
|
432 | if (handles[static_cast<size_t>(i)] == kValueHandleNone) { |
| 443 | ✗ | LOG(FATAL) << "KVEngine bulk value allocation failed, key=" << keys[i] | |
| 444 | ✗ | << " size=" << value_size; | |
| 445 | } | ||
| 446 | } | ||
| 447 |
1/2✓ Branch 3 taken 54 times.
✗ Branch 4 not taken.
|
54 | index_->BatchPut(keys, handles.data(), 0); |
| 448 | 54 | } | |
| 449 | |||
| 450 | ✗ | void Util() override { | |
| 451 | ✗ | LOG(INFO) << "KVEngine index utilization=" << index_->Utilization() | |
| 452 | ✗ | << " value=" << value_store_->GetInfo(); | |
| 453 | ✗ | } | |
| 454 | |||
| 455 | ✗ | void DebugInfo() const override { | |
| 456 | ✗ | index_->DebugInfo(); | |
| 457 | ✗ | LOG(INFO) << value_store_->GetInfo(); | |
| 458 | ✗ | } | |
| 459 | |||
| 460 | ✗ | std::string ExtraResultFields() const override { | |
| 461 | ✗ | return value_store_ ? value_store_->ExtraResultFields() : ""; | |
| 462 | } | ||
| 463 | |||
| 464 | private: | ||
| 465 | 354022 | void PutInternal(uint64_t key, | |
| 466 | const void* data, | ||
| 467 | size_t size, | ||
| 468 | unsigned tid, | ||
| 469 | bool emit_fence) { | ||
| 470 | (void)tid; | ||
| 471 | (void)emit_fence; | ||
| 472 | 354022 | Value_t new_handle = value_store_->AllocAndWrite(data, size); | |
| 473 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 354022 times.
|
354022 | if (new_handle == kValueHandleNone) { |
| 474 | ✗ | LOG(FATAL) << "KVEngine value allocation failed, key=" << key | |
| 475 | ✗ | << " size=" << size; | |
| 476 | return; | ||
| 477 | } | ||
| 478 | 354022 | Value_t old_handle = index_->Put(key, new_handle, tid); | |
| 479 |
2/2✓ Branch 0 taken 156442 times.
✓ Branch 1 taken 197580 times.
|
354022 | if (old_handle != kValueHandleNone) { |
| 480 | 156442 | value_store_->Retire(old_handle); | |
| 481 | } | ||
| 482 | } | ||
| 483 | |||
| 484 | BaseKVConfig config_; | ||
| 485 | std::unique_ptr<Index> index_; | ||
| 486 | std::unique_ptr<ValueStore> value_store_; | ||
| 487 | int num_threads_ = 0; | ||
| 488 | size_t default_value_size_hint_ = 0; | ||
| 489 | }; | ||
| 490 | |||
| 491 | FACTORY_REGISTER( | ||
| 492 | BaseKV, KVEngineComposite, KVEngineComposite, const BaseKVConfig&); | ||
| 493 |