GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 33.5% 128 / 0 / 382
Functions: 23.1% 9 / 0 / 39
Branches: 21.3% 136 / 0 / 638

framework/pytorch/op_torch.cc
Line Branch Exec Source
1 #include <torch/extension.h>
2
3 #include <cstdlib>
4 #include <cstring>
5 #include <iostream>
6 #include <chrono>
7 #include <mutex>
8 #include <string>
9 #include <unordered_map>
10 #include <unistd.h>
11 #include "base/tensor.h"
12 #include "framework/op.h"
13 #include "ps/local_shm/local_shm_client.h"
14 // Log level: 0=ERROR, 1=WARNING, 2=INFO, 3=DEBUG
15 #include <glog/logging.h>
16
17 #ifdef RECSTORE_ENABLE_GPU_CACHE
18 # include "framework/gpu/gpu_embedding_cache.h"
19 #endif
20
21 #if __has_include(<cuda_runtime_api.h>)
22 # include <ATen/cuda/CUDAContext.h>
23 # include <c10/cuda/CUDAException.h>
24 # include <c10/cuda/CUDAGuard.h>
25 # include <cuda_runtime_api.h>
26 # define RECSTORE_HAS_CUDA_RUNTIME_API 1
27 #else
28 # define RECSTORE_HAS_CUDA_RUNTIME_API 0
29 #endif
30
31 namespace recstore {
32 namespace framework {
33
34 namespace {
35
// Returns true for the PS backends accepted by the local flat lookup/update
// entry points (local_lookup_flat_torch / local_update_flat_torch).
bool IsLocalFastPathBackend(const std::string& backend) {
  static constexpr const char* kFastPathBackends[] = {"local_shm", "hierkv"};
  for (const char* candidate : kFastPathBackends) {
    if (backend == candidate) {
      return true;
    }
  }
  return false;
}
39
// Slot indices into g_last_local_lookup_flat_profile. Every slot holds a
// duration in milliseconds (as produced by ElapsedMs) for one stage of the
// most recent local flat lookup on the current thread.
enum LookupProfileIndex : std::size_t {
  kLookupTotalMs = 0,             // whole call, entry to return
  kLookupKeysStageMs = 1,         // staging CUDA keys into pinned host memory
  kLookupSubmitMs = 2,            // SubmitLocalLookupFlat
  kLookupWaitMs = 3,              // WaitLocalLookupFlat
  kLookupPayloadPinMs = 4,        // page-locking the local_shm payload
  kLookupFallbackCopyMs = 5,      // memcpy into a pinned bounce buffer
  kLookupValuesH2DEnqueueMs = 6,  // host-to-device copy of the values
  kLookupProfileSize = 7,         // number of slots (not a stage)
};
50
// Slot indices into g_last_local_update_flat_profile. Every slot holds a
// duration in milliseconds (as produced by ElapsedMs) for one stage of the
// most recent local flat update on the current thread.
enum UpdateProfileIndex : std::size_t {
  kUpdateTotalMs = 0,       // whole call, entry to return
  kUpdateKeysStageMs = 1,   // issuing the copy of CUDA keys to pinned host
  kUpdateGradsStageMs = 2,  // issuing the copy of CUDA grads to pinned host
  kUpdateShmCallMs = 3,     // the backend LocalUpdateFlat call
  kUpdateStageWaitMs = 4,   // waiting for non_blocking staging copies
  kUpdateProfileSize = 5,   // number of slots (not a stage)
};
59
60 thread_local std::vector<double>
61 g_last_local_lookup_flat_profile(kLookupProfileSize, 0.0);
62 thread_local std::vector<double>
63 g_last_local_update_flat_profile(kUpdateProfileSize, 0.0);
64
// Current reading of the monotonic clock used for every profiling timestamp.
inline std::chrono::steady_clock::time_point SteadyNow() {
  using Clock = std::chrono::steady_clock;
  return Clock::now();
}
68
// Fractional milliseconds elapsed on the steady clock since `start`.
inline double ElapsedMs(std::chrono::steady_clock::time_point start) {
  const std::chrono::duration<double, std::milli> elapsed =
      std::chrono::steady_clock::now() - start;
  return elapsed.count();
}
74
75 inline void ResetLocalLookupFlatProfile() {
76 std::fill(g_last_local_lookup_flat_profile.begin(),
77 g_last_local_lookup_flat_profile.end(),
78 0.0);
79 }
80
81 inline void ResetLocalUpdateFlatProfile() {
82 std::fill(g_last_local_update_flat_profile.begin(),
83 g_last_local_update_flat_profile.end(),
84 0.0);
85 }
86
87 #ifdef RECSTORE_ENABLE_GPU_CACHE
// --- Adaptive GPU-cache lookup bypass (state is per thread) -----------------
// Only batches with at least this many keys are tracked and may bypass.
constexpr int64_t kGpuCacheBypassMinRows{1024};
// Consecutive low-hit tracked batches needed before lookups bypass the cache.
constexpr int kGpuCacheLowHitLimit{1};
// A tracked batch whose hit ratio is below this fraction counts as low-hit.
constexpr double kGpuCacheLowHitRatio{0.05};

// Consecutive low-hit tracked batches observed on this thread.
thread_local int g_gpu_cache_low_hit_streak{0};
// True while this thread is in bypass mode (set by MarkGpuCacheLookupBypassed).
thread_local bool g_gpu_cache_lookup_bypassed{false};
// Master switch for the heuristic (see SetGpuCacheLookupBypassEnabled).
thread_local bool g_gpu_cache_lookup_bypass_enabled{true};

// Defined further down; declared here because the bypass helpers call it.
void SafeClearGpuCacheNoThrow();
96
97 void ResetGpuCacheBypassState() {
98 g_gpu_cache_low_hit_streak = 0;
99 g_gpu_cache_lookup_bypassed = false;
100 }
101
102 bool ShouldBypassGpuCacheLookup(int64_t num_keys) {
103 return g_gpu_cache_lookup_bypass_enabled &&
104 num_keys >= kGpuCacheBypassMinRows &&
105 g_gpu_cache_low_hit_streak >= kGpuCacheLowHitLimit;
106 }
107
108 void RecordGpuCacheLookupOutcome(
109 int64_t num_keys, double hit_count, double request_count) {
110 if (num_keys < kGpuCacheBypassMinRows || request_count <= 0.0) {
111 return;
112 }
113 const double hit_ratio = hit_count / request_count;
114 if (hit_ratio < kGpuCacheLowHitRatio) {
115 ++g_gpu_cache_low_hit_streak;
116 } else {
117 g_gpu_cache_low_hit_streak = 0;
118 g_gpu_cache_lookup_bypassed = false;
119 }
120 }
121
122 bool ShouldBypassGpuCacheMaintenance(int64_t num_keys) {
123 return g_gpu_cache_lookup_bypass_enabled &&
124 num_keys >= kGpuCacheBypassMinRows && g_gpu_cache_lookup_bypassed;
125 }
126
127 void MarkGpuCacheLookupBypassed() {
128 if (!g_gpu_cache_lookup_bypassed) {
129 SafeClearGpuCacheNoThrow();
130 g_gpu_cache_low_hit_streak = kGpuCacheLowHitLimit;
131 }
132 g_gpu_cache_lookup_bypassed = true;
133 }
134
135 void EnsureGpuCacheSafeForLookup() {
136 if (g_gpu_cache_lookup_bypassed) {
137 SafeClearGpuCacheNoThrow();
138 ResetGpuCacheBypassState();
139 }
140 }
141
142 void SafeClearGpuCacheNoThrow() {
143 try {
144 gpu::ClearGpuCache();
145 } catch (const std::exception& e) {
146 LOG(WARNING) << "Failed to clear GPU cache: " << e.what();
147 } catch (...) {
148 LOG(WARNING) << "Failed to clear GPU cache: unknown exception";
149 }
150 }
151
152 void SetGpuCacheLookupBypassEnabled(bool enabled) {
153 g_gpu_cache_lookup_bypass_enabled = enabled;
154 if (!enabled) {
155 ResetGpuCacheBypassState();
156 }
157 }
158
159 void MaintainGpuCacheAfterUpdateNoThrow(const torch::Tensor& keys,
160 const torch::Tensor& grads,
161 int64_t embedding_dim) {
162 (void)grads;
163 if (!gpu::IsGpuCacheEnabled()) {
164 return;
165 }
166 if (ShouldBypassGpuCacheMaintenance(keys.numel())) {
167 gpu::ResetLastGpuCacheProfile();
168 return;
169 }
170 if (gpu::CanUseGpuCache(keys, embedding_dim)) {
171 try {
172 gpu::InvalidateGpuCache(keys);
173 return;
174 } catch (const std::exception& e) {
175 LOG(WARNING) << "GPU cache invalidation failed after backend update "
176 "succeeded; clearing cache and continuing: "
177 << e.what();
178 } catch (...) {
179 LOG(WARNING) << "GPU cache invalidation failed after backend update "
180 "succeeded; clearing cache and continuing: "
181 << "unknown exception";
182 }
183 }
184 SafeClearGpuCacheNoThrow();
185 gpu::ResetLastGpuCacheProfile();
186 }
187 #endif
188
189 } // namespace
190
191 static inline base::RecTensor
192 108 ToRecTensor(const torch::Tensor& tensor, base::DataType dtype) {
193 108 std::vector<int64_t> shape;
194
3/4
✓ Branch 1 taken 270 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 162 times.
✓ Branch 4 taken 108 times.
270 for (int i = 0; i < tensor.dim(); ++i) {
195
2/4
✓ Branch 1 taken 162 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 162 times.
✗ Branch 5 not taken.
162 shape.push_back(tensor.size(i));
196 }
197
2/4
✓ Branch 1 taken 108 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
216 return base::RecTensor(const_cast<void*>(tensor.data_ptr()), shape, dtype);
198 108 }
199
200 static torch::TensorOptions PinnedCpuOptions(torch::ScalarType dtype) {
201 return torch::TensorOptions()
202 .device(torch::kCPU)
203 .dtype(dtype)
204 .pinned_memory(true);
205 }
206
207 static torch::Tensor StageCudaTensorToPinnedCpu(const torch::Tensor& tensor,
208 torch::ScalarType dtype) {
209 auto cpu_tensor = torch::empty(tensor.sizes(), PinnedCpuOptions(dtype));
210 cpu_tensor.copy_(tensor.to(dtype), /*non_blocking=*/false);
211 return cpu_tensor;
212 }
213
214 static torch::Tensor
215 StageCudaTensorToPinnedCpuAsyncNoCast(const torch::Tensor& tensor) {
216 auto cpu_tensor =
217 torch::empty(tensor.sizes(), PinnedCpuOptions(tensor.scalar_type()));
218 cpu_tensor.copy_(tensor, /*non_blocking=*/true);
219 return cpu_tensor;
220 }
221
222 static void SynchronizeCurrentCudaStreamForTensor(const torch::Tensor& tensor) {
223 #if RECSTORE_HAS_CUDA_RUNTIME_API
224 if (!tensor.is_cuda()) {
225 return;
226 }
227 c10::cuda::CUDAGuard device_guard(tensor.device());
228 C10_CUDA_CHECK(
229 cudaStreamSynchronize(at::cuda::getCurrentCUDAStream().stream()));
230 #else
231 (void)tensor;
232 #endif
233 }
234
// Page-lock (cudaHostRegister) the pages covering [ptr, ptr + bytes) so that
// host-to-device copies can read the local_shm payload directly.
//
// Successful registrations are remembered in a function-static map keyed by
// the page-aligned base address and guarded by a mutex. A later call with the
// same base and a larger extent registers only the missing tail. Returns true
// when the range is considered registered; false when built without the CUDA
// runtime API, for a null/empty range, or when registration fails.
static bool EnsurePinnedLocalShmPayload(const void* ptr, std::size_t bytes) {
#if !RECSTORE_HAS_CUDA_RUNTIME_API
  (void)ptr;
  (void)bytes;
  return false;
#else
  if (ptr == nullptr || bytes == 0) {
    return false;
  }
  const long page_size = ::sysconf(_SC_PAGESIZE);
  if (page_size <= 0) {
    return false;
  }
  const std::size_t page_bytes = static_cast<std::size_t>(page_size);
  const uintptr_t page_mask = ~(static_cast<uintptr_t>(page_bytes) - 1U);
  const uintptr_t raw_begin = reinterpret_cast<uintptr_t>(ptr);
  const uintptr_t raw_end = raw_begin + bytes;
  const uintptr_t page_begin = raw_begin & page_mask;
  const uintptr_t page_end = (raw_end + page_bytes - 1U) & page_mask;
  const std::size_t required_bytes =
      static_cast<std::size_t>(page_end - page_begin);

  static std::mutex mu;
  static std::unordered_map<uintptr_t, std::size_t> registered_bytes_by_base;
  std::lock_guard<std::mutex> guard(mu);

  // find() rather than operator[], so a failed registration leaves no entry.
  std::size_t existing_bytes = 0;
  const auto it = registered_bytes_by_base.find(page_begin);
  if (it != registered_bytes_by_base.end()) {
    existing_bytes = it->second;
  }
  if (existing_bytes >= required_bytes) {
    return true;
  }

  // Only register the tail not covered by an earlier call with this base.
  void* register_ptr = reinterpret_cast<void*>(page_begin + existing_bytes);
  const std::size_t register_bytes = required_bytes - existing_bytes;
  const cudaError_t err =
      cudaHostRegister(register_ptr, register_bytes, cudaHostRegisterPortable);
  if (err != cudaSuccess) {
    // A failed runtime call is also recorded as this thread's "last error".
    // Consume it so a later cudaGetLastError() check (e.g. PyTorch's
    // post-kernel-launch check) does not report this unrelated failure.
    (void)cudaGetLastError();
    if (err != cudaErrorHostMemoryAlreadyRegistered) {
      LOG(WARNING) << "cudaHostRegister failed for local_shm payload: "
                   << cudaGetErrorString(err)
                   << " base=" << reinterpret_cast<void*>(page_begin)
                   << " bytes=" << required_bytes;
      return false;
    }
  }
  registered_bytes_by_base[page_begin] = required_bytes;
  return true;
#endif
}
281
282 32 torch::Tensor emb_read_torch(const torch::Tensor& keys, int64_t embedding_dim) {
283
1/2
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
32 bool is_cuda = keys.is_cuda();
284
1/2
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
32 auto orig_device = keys.device();
285
2/6
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 6 taken 32 times.
✗ Branch 7 not taken.
32 torch::Tensor cpu_keys = is_cuda ? keys.cpu() : keys;
286
287
2/4
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 32 times.
32 TORCH_CHECK(cpu_keys.dim() == 1, "Keys tensor must be 1-dimensional");
288
2/4
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 32 times.
32 TORCH_CHECK(cpu_keys.scalar_type() == torch::kInt64,
289 "Keys tensor must have dtype int64");
290
2/4
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 32 times.
32 TORCH_CHECK(cpu_keys.is_contiguous(), "Keys tensor must be contiguous");
291
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
32 TORCH_CHECK(embedding_dim > 0, "Embedding dimension must be positive");
292
293
1/2
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
32 const int64_t num_keys = cpu_keys.size(0);
294
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
32 if (num_keys == 0) {
295 return torch::empty(
296 {0, embedding_dim}, torch::TensorOptions().dtype(torch::kFloat32));
297 }
298
299
1/2
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
32 auto op = GetKVClientOp();
300
301 auto cpu_values = torch::empty(
302
2/4
✓ Branch 2 taken 32 times.
✗ Branch 3 not taken.
✓ Branch 8 taken 32 times.
✗ Branch 9 not taken.
32 {num_keys, embedding_dim}, torch::TensorOptions().dtype(torch::kFloat32));
303
304
1/2
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
32 base::RecTensor rec_keys = ToRecTensor(cpu_keys, base::DataType::UINT64);
305
1/2
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
32 base::RecTensor rec_values = ToRecTensor(cpu_values, base::DataType::FLOAT32);
306
307
1/2
✓ Branch 2 taken 32 times.
✗ Branch 3 not taken.
32 op->EmbRead(rec_keys, rec_values);
308
309
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
32 if (is_cuda) {
310 return cpu_values.to(orig_device);
311 }
312 32 return cpu_values;
313 32 }
314
315 static std::shared_ptr<KVClientOp> GetConcreteKVClientOp() {
316 auto op = GetKVClientOp();
317 auto kv_op = std::dynamic_pointer_cast<KVClientOp>(op);
318 TORCH_CHECK(kv_op != nullptr, "storage backend is not KVClientOp");
319 return kv_op;
320 }
321
322 static torch::Tensor BackendLocalLookupFlat(
323 const std::shared_ptr<KVClientOp>& kv_op,
324 const torch::Tensor& cpu_keys,
325 const torch::Device& result_device,
326 bool result_on_cuda,
327 int64_t embedding_dim,
328 const std::chrono::steady_clock::time_point& total_start,
329 bool record_profile = true) {
330 const int64_t num_keys = cpu_keys.size(0);
331 base::RecTensor rec_keys = ToRecTensor(cpu_keys, base::DataType::UINT64);
332 if (kv_op->CurrentPSBackend() != "local_shm") {
333 auto cpu_values =
334 result_on_cuda
335 ? torch::empty({num_keys, embedding_dim},
336 PinnedCpuOptions(torch::kFloat32))
337 : torch::empty({num_keys, embedding_dim},
338 torch::TensorOptions()
339 .device(torch::kCPU)
340 .dtype(torch::kFloat32));
341 base::RecTensor rec_values =
342 ToRecTensor(cpu_values, base::DataType::FLOAT32);
343 kv_op->LocalLookupFlat(rec_keys, rec_values);
344 if (record_profile) {
345 g_last_local_lookup_flat_profile[kLookupTotalMs] = ElapsedMs(total_start);
346 }
347 if (result_on_cuda) {
348 return cpu_values.to(result_device, /*non_blocking=*/true);
349 }
350 return cpu_values;
351 }
352
353 if (!result_on_cuda) {
354 auto cpu_values = torch::empty(
355 {num_keys, embedding_dim},
356 torch::TensorOptions().device(torch::kCPU).dtype(torch::kFloat32));
357 base::RecTensor rec_values =
358 ToRecTensor(cpu_values, base::DataType::FLOAT32);
359 kv_op->LocalLookupFlat(rec_keys, rec_values);
360 if (record_profile) {
361 g_last_local_lookup_flat_profile[kLookupTotalMs] = ElapsedMs(total_start);
362 }
363 return cpu_values;
364 }
365
366 LocalShmFlatGetHandle handle;
367 const auto submit_start = SteadyNow();
368 TORCH_CHECK(
369 kv_op->SubmitLocalLookupFlat(rec_keys, embedding_dim, &handle) == 0,
370 "Failed to submit local_shm flat lookup.");
371 if (record_profile) {
372 g_last_local_lookup_flat_profile[kLookupSubmitMs] = ElapsedMs(submit_start);
373 }
374 const auto wait_start = SteadyNow();
375 const int wait_ret = kv_op->WaitLocalLookupFlat(&handle);
376 if (record_profile) {
377 g_last_local_lookup_flat_profile[kLookupWaitMs] = ElapsedMs(wait_start);
378 }
379 if (wait_ret != 0) {
380 kv_op->ReleaseLocalLookupFlat(&handle);
381 TORCH_CHECK(false, "Failed to wait for local_shm flat lookup.");
382 }
383 const float* payload_values = handle.values;
384 const int64_t payload_rows = handle.num_rows;
385 const int64_t payload_dim = handle.embedding_dim;
386 const std::size_t payload_bytes =
387 static_cast<std::size_t>(handle.output_bytes);
388 const int64_t expected_bytes =
389 num_keys * embedding_dim * static_cast<int64_t>(sizeof(float));
390 if (payload_values == nullptr || payload_rows != num_keys ||
391 payload_dim != embedding_dim ||
392 static_cast<int64_t>(payload_bytes) != expected_bytes) {
393 kv_op->ReleaseLocalLookupFlat(&handle);
394 TORCH_CHECK(false,
395 "local_shm flat lookup returned unexpected payload metadata.");
396 }
397 const auto pin_start = SteadyNow();
398 const bool payload_is_pinned =
399 EnsurePinnedLocalShmPayload(payload_values, payload_bytes);
400 if (record_profile) {
401 g_last_local_lookup_flat_profile[kLookupPayloadPinMs] =
402 ElapsedMs(pin_start);
403 }
404 if (payload_is_pinned) {
405 try {
406 LocalShmFlatGetHandle handle_for_release = handle;
407 auto cpu_view = torch::from_blob(
408 const_cast<float*>(payload_values),
409 {num_keys, embedding_dim},
410 [kv_op, handle_for_release](void* /*unused*/) mutable {
411 kv_op->ReleaseLocalLookupFlat(&handle_for_release);
412 },
413 PinnedCpuOptions(torch::kFloat32));
414 const auto h2d_start = SteadyNow();
415 auto result = cpu_view.to(result_device, /*non_blocking=*/true);
416 if (record_profile) {
417 g_last_local_lookup_flat_profile[kLookupValuesH2DEnqueueMs] =
418 ElapsedMs(h2d_start);
419 g_last_local_lookup_flat_profile[kLookupTotalMs] =
420 ElapsedMs(total_start);
421 }
422 return result;
423 } catch (...) {
424 kv_op->ReleaseLocalLookupFlat(&handle);
425 throw;
426 }
427 }
428
429 auto cpu_values = torch::empty(
430 {num_keys, embedding_dim}, PinnedCpuOptions(torch::kFloat32));
431 const auto fallback_copy_start = SteadyNow();
432 std::memcpy(cpu_values.data_ptr<float>(), payload_values, payload_bytes);
433 if (record_profile) {
434 g_last_local_lookup_flat_profile[kLookupFallbackCopyMs] =
435 ElapsedMs(fallback_copy_start);
436 }
437 kv_op->ReleaseLocalLookupFlat(&handle);
438 const auto h2d_start = SteadyNow();
439 auto result = cpu_values.to(result_device, /*non_blocking=*/true);
440 if (record_profile) {
441 g_last_local_lookup_flat_profile[kLookupValuesH2DEnqueueMs] =
442 ElapsedMs(h2d_start);
443 g_last_local_lookup_flat_profile[kLookupTotalMs] = ElapsedMs(total_start);
444 }
445 return result;
446 }
447
448 torch::Tensor
449 local_lookup_flat_torch(const torch::Tensor& keys, int64_t embedding_dim) {
450 ResetLocalLookupFlatProfile();
451 #ifdef RECSTORE_ENABLE_GPU_CACHE
452 gpu::ResetLastGpuCacheProfile();
453 #endif
454 const auto total_start = SteadyNow();
455 const bool is_cuda = keys.is_cuda();
456 auto orig_device = keys.device();
457
458 TORCH_CHECK(keys.dim() == 1, "Keys tensor must be 1-dimensional");
459 TORCH_CHECK(keys.scalar_type() == torch::kInt64,
460 "Keys tensor must have dtype int64");
461 TORCH_CHECK(keys.is_contiguous(), "Keys tensor must be contiguous");
462 TORCH_CHECK(embedding_dim > 0, "Embedding dimension must be positive");
463
464 auto kv_op = GetConcreteKVClientOp();
465 TORCH_CHECK(IsLocalFastPathBackend(kv_op->CurrentPSBackend()),
466 "local_lookup_flat requires local_shm or hierkv backend, but "
467 "current backend is ",
468 kv_op->CurrentPSBackend());
469
470 const int64_t num_keys = keys.size(0);
471 if (num_keys == 0) {
472 return torch::empty(
473 {0, embedding_dim}, torch::TensorOptions().dtype(torch::kFloat32));
474 }
475
476 #ifdef RECSTORE_ENABLE_GPU_CACHE
477 const bool can_use_gpu_cache = gpu::CanUseGpuCache(keys, embedding_dim);
478 const bool bypass_gpu_cache_lookup =
479 can_use_gpu_cache && ShouldBypassGpuCacheLookup(num_keys);
480 if (bypass_gpu_cache_lookup) {
481 MarkGpuCacheLookupBypassed();
482 }
483 if (can_use_gpu_cache && !bypass_gpu_cache_lookup) {
484 EnsureGpuCacheSafeForLookup();
485 try {
486 auto cache_result = gpu::QueryGpuCache(keys, embedding_dim);
487 RecordGpuCacheLookupOutcome(
488 num_keys,
489 static_cast<double>(num_keys - cache_result.missing_count),
490 static_cast<double>(num_keys));
491 if (cache_result.missing_count == 0) {
492 g_last_local_lookup_flat_profile[kLookupTotalMs] =
493 ElapsedMs(total_start);
494 return cache_result.values;
495 }
496
497 const auto backend_start = SteadyNow();
498 auto miss_values = BackendLocalLookupFlat(
499 kv_op,
500 cache_result.missing_keys_cpu.contiguous(),
501 orig_device,
502 /*result_on_cuda=*/false,
503 embedding_dim,
504 total_start);
505 const double backend_ms = ElapsedMs(backend_start);
506 gpu::AddGpuCacheBackendLookupMs(backend_ms);
507 auto miss_keys_cuda =
508 cache_result.missing_keys_cpu.to(orig_device, /*non_blocking=*/false);
509 auto miss_values_cuda =
510 miss_values.to(orig_device, /*non_blocking=*/false);
511 gpu::FillGpuCache(miss_keys_cuda, miss_values_cuda);
512 gpu::ScatterMissValues(&cache_result.values,
513 cache_result.missing_positions_cpu,
514 miss_values_cuda);
515 g_last_local_lookup_flat_profile[kLookupTotalMs] = ElapsedMs(total_start);
516 return cache_result.values;
517 } catch (const std::exception& e) {
518 LOG(WARNING)
519 << "GPU cache lookup failed; clearing cache and falling back: "
520 << e.what();
521 SafeClearGpuCacheNoThrow();
522 gpu::ResetLastGpuCacheProfile();
523 } catch (...) {
524 LOG(WARNING)
525 << "GPU cache lookup failed; clearing cache and falling back: "
526 << "unknown exception";
527 SafeClearGpuCacheNoThrow();
528 gpu::ResetLastGpuCacheProfile();
529 }
530 }
531 #endif
532
533 torch::Tensor cpu_keys = keys;
534 if (is_cuda) {
535 const auto stage_start = SteadyNow();
536 cpu_keys = StageCudaTensorToPinnedCpu(keys, torch::kInt64);
537 g_last_local_lookup_flat_profile[kLookupKeysStageMs] =
538 ElapsedMs(stage_start);
539 }
540
541 return BackendLocalLookupFlat(
542 kv_op, cpu_keys, orig_device, is_cuda, embedding_dim, total_start);
543 }
544
545 // Async prefetch: returns a unique prefetch id (uint64_t)
546 2 int64_t emb_prefetch_torch(const torch::Tensor& keys) {
547
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(keys.dim() == 1, "Keys tensor must be 1-dimensional");
548
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(keys.scalar_type() == torch::kInt64,
549 "Keys tensor must have dtype int64");
550
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(keys.is_contiguous(), "Keys tensor must be contiguous");
551
552
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 auto op = GetKVClientOp();
553
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 torch::Tensor cpu_keys = keys;
554
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 if (keys.is_cuda()) {
555 cpu_keys = keys.cpu();
556 }
557
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 base::RecTensor rec_keys = ToRecTensor(cpu_keys, base::DataType::UINT64);
558 // Dummy values tensor (unused by backend prefetch implementation)
559
2/4
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
2 auto dummy_vals = torch::empty({0, 0}, keys.options().dtype(torch::kFloat32));
560
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 base::RecTensor rec_vals = ToRecTensor(dummy_vals, base::DataType::FLOAT32);
561
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 uint64_t pid = op->EmbPrefetch(rec_keys, rec_vals);
562 2 return static_cast<int64_t>(pid);
563 2 }
564
565 // Wait for prefetch and return result tensor [N, embedding_dim] on CPU
566 torch::Tensor
567 2 emb_wait_result_torch(int64_t prefetch_id, int64_t embedding_dim) {
568
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
2 TORCH_CHECK(embedding_dim > 0, "Embedding dimension must be positive");
569
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 auto op = GetKVClientOp();
570
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 op->WaitForPrefetch(static_cast<uint64_t>(prefetch_id));
571 2 std::vector<float> flat_values;
572 2 int64_t L = 0;
573
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 op->GetPretchResultFlat(
574 static_cast<uint64_t>(prefetch_id), &flat_values, &L, embedding_dim);
575 auto options =
576
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
577
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 auto out = torch::empty({L, embedding_dim}, options);
578
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
2 if (L > 0 && !flat_values.empty()) {
579
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 std::memcpy(out.data_ptr<float>(),
580 2 flat_values.data(),
581 2 static_cast<size_t>(L) * static_cast<size_t>(embedding_dim) *
582 sizeof(float));
583 }
584 4 return out;
585 2 }
586
587 void emb_update_torch(const torch::Tensor& keys, const torch::Tensor& grads) {
588 throw std::runtime_error(
589 "emb_update_torch is deprecated. Use the Python-based sparse "
590 "optimizer.");
591 }
592
593 2 void emb_update_table_torch(const std::string& table_name,
594 const torch::Tensor& keys,
595 const torch::Tensor& grads) {
596
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
2 TORCH_CHECK(!table_name.empty(), "table_name must be non-empty");
597
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(keys.dim() == 1, "Keys tensor must be 1-dimensional");
598
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(keys.scalar_type() == torch::kInt64,
599 "Keys tensor must have dtype int64");
600
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(keys.is_contiguous(), "Keys tensor must be contiguous");
601
602
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(grads.dim() == 2, "Grads tensor must be 2-dimensional");
603
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(grads.scalar_type() == torch::kFloat32,
604 "Grads tensor must have dtype float32");
605
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 TORCH_CHECK(grads.is_contiguous(), "Grads tensor must be contiguous");
606
3/6
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
2 TORCH_CHECK(keys.size(0) == grads.size(0),
607 "Keys and grads tensors must have the same number of entries");
608
609
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 if (keys.size(0) == 0) {
610 return;
611 }
612
613
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 auto op = GetKVClientOp();
614
615
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 torch::Tensor cpu_keys = keys;
616
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 torch::Tensor cpu_grads = grads;
617
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 if (keys.is_cuda()) {
618 cpu_keys = keys.cpu();
619 }
620
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 if (grads.is_cuda()) {
621 cpu_grads = grads.cpu();
622 }
623
624
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 base::RecTensor rec_keys = ToRecTensor(cpu_keys, base::DataType::UINT64);
625
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 base::RecTensor rec_grads = ToRecTensor(cpu_grads, base::DataType::FLOAT32);
626
627
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 op->EmbUpdate(table_name, rec_keys, rec_grads);
628 #ifdef RECSTORE_ENABLE_GPU_CACHE
629 MaintainGpuCacheAfterUpdateNoThrow(keys, grads, grads.size(1));
630 #endif
631 2 }
632
633 void local_update_flat_torch(const std::string& table_name,
634 const torch::Tensor& keys,
635 const torch::Tensor& grads) {
636 ResetLocalUpdateFlatProfile();
637 #ifdef RECSTORE_ENABLE_GPU_CACHE
638 gpu::ResetLastGpuCacheProfile();
639 #endif
640 const auto total_start = SteadyNow();
641 TORCH_CHECK(!table_name.empty(), "table_name must be non-empty");
642 TORCH_CHECK(keys.dim() == 1, "Keys tensor must be 1-dimensional");
643 TORCH_CHECK(keys.scalar_type() == torch::kInt64,
644 "Keys tensor must have dtype int64");
645 TORCH_CHECK(keys.is_contiguous(), "Keys tensor must be contiguous");
646
647 TORCH_CHECK(grads.dim() == 2, "Grads tensor must be 2-dimensional");
648 TORCH_CHECK(grads.scalar_type() == torch::kFloat32,
649 "Grads tensor must have dtype float32");
650 TORCH_CHECK(grads.is_contiguous(), "Grads tensor must be contiguous");
651 TORCH_CHECK(keys.size(0) == grads.size(0),
652 "Keys and grads tensors must have the same number of entries");
653
654 auto kv_op = GetConcreteKVClientOp();
655 TORCH_CHECK(IsLocalFastPathBackend(kv_op->CurrentPSBackend()),
656 "local_update_flat requires local_shm or hierkv backend, but "
657 "current backend is ",
658 kv_op->CurrentPSBackend());
659
660 if (keys.size(0) == 0) {
661 g_last_local_update_flat_profile[kUpdateTotalMs] = ElapsedMs(total_start);
662 return;
663 }
664
665 torch::Tensor cpu_keys = keys;
666 const bool can_async_stage_cuda =
667 (keys.is_cuda() || grads.is_cuda()) &&
668 (!keys.is_cuda() || !grads.is_cuda() || keys.device() == grads.device());
669 bool staged_cuda_async = false;
670 if (keys.is_cuda()) {
671 const auto keys_stage_start = SteadyNow();
672 if (can_async_stage_cuda) {
673 cpu_keys = StageCudaTensorToPinnedCpuAsyncNoCast(keys);
674 staged_cuda_async = true;
675 } else {
676 cpu_keys = StageCudaTensorToPinnedCpu(keys, torch::kInt64);
677 }
678 g_last_local_update_flat_profile[kUpdateKeysStageMs] =
679 ElapsedMs(keys_stage_start);
680 }
681 torch::Tensor cpu_grads = grads;
682 if (grads.is_cuda()) {
683 const auto grads_stage_start = SteadyNow();
684 if (can_async_stage_cuda) {
685 cpu_grads = StageCudaTensorToPinnedCpuAsyncNoCast(grads);
686 staged_cuda_async = true;
687 } else {
688 cpu_grads = StageCudaTensorToPinnedCpu(grads, torch::kFloat32);
689 }
690 g_last_local_update_flat_profile[kUpdateGradsStageMs] =
691 ElapsedMs(grads_stage_start);
692 }
693 if (staged_cuda_async) {
694 const auto stage_wait_start = SteadyNow();
695 SynchronizeCurrentCudaStreamForTensor(keys.is_cuda() ? keys : grads);
696 g_last_local_update_flat_profile[kUpdateStageWaitMs] =
697 ElapsedMs(stage_wait_start);
698 }
699
700 base::RecTensor rec_keys = ToRecTensor(cpu_keys, base::DataType::UINT64);
701 base::RecTensor rec_grads = ToRecTensor(cpu_grads, base::DataType::FLOAT32);
702
703 const auto shm_call_start = SteadyNow();
704 try {
705 kv_op->LocalUpdateFlat(table_name, rec_keys, rec_grads);
706 } catch (...) {
707 #ifdef RECSTORE_ENABLE_GPU_CACHE
708 if (gpu::IsGpuCacheEnabled()) {
709 SafeClearGpuCacheNoThrow();
710 gpu::ResetLastGpuCacheProfile();
711 }
712 #endif
713 throw;
714 }
715 g_last_local_update_flat_profile[kUpdateShmCallMs] =
716 ElapsedMs(shm_call_start);
717
718 #ifdef RECSTORE_ENABLE_GPU_CACHE
719 MaintainGpuCacheAfterUpdateNoThrow(keys, grads, grads.size(1));
720 #endif
721
722 g_last_local_update_flat_profile[kUpdateTotalMs] = ElapsedMs(total_start);
723 }
724
725 std::vector<double> get_last_local_lookup_flat_profile_torch() {
726 return g_last_local_lookup_flat_profile;
727 }
728
729 std::vector<double> get_last_local_update_flat_profile_torch() {
730 return g_last_local_update_flat_profile;
731 }
732
733 bool warmup_local_lookup_flat_cuda_region_torch() {
734 auto kv_op = GetConcreteKVClientOp();
735 const void* payload_base = nullptr;
736 std::size_t payload_bytes = 0;
737 if (!kv_op->GetLocalLookupFlatPayloadRegion(&payload_base, &payload_bytes)) {
738 return false;
739 }
740 return EnsurePinnedLocalShmPayload(payload_base, payload_bytes);
741 }
742
743 10 bool init_embedding_table_torch(const std::string& table_name,
744 int64_t num_embeddings,
745 int64_t embedding_dim) {
746
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 10 times.
10 TORCH_CHECK(!table_name.empty(), "table_name must be non-empty");
747
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
10 TORCH_CHECK(num_embeddings > 0, "num_embeddings must be positive");
748
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
10 TORCH_CHECK(embedding_dim > 0, "embedding_dim must be positive");
749
750 10 EmbeddingTableConfig cfg{};
751 10 cfg.num_embeddings = static_cast<uint64_t>(num_embeddings);
752 10 cfg.embedding_dim = static_cast<uint64_t>(embedding_dim);
753
754
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 auto op = GetKVClientOp();
755
1/2
✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
10 const bool ok = op->InitEmbeddingTable(table_name, cfg);
756 #ifdef RECSTORE_ENABLE_GPU_CACHE
757 if (ok && gpu::IsGpuCacheEnabled()) {
758 SafeClearGpuCacheNoThrow();
759 gpu::ResetLastGpuCacheProfile();
760 }
761 #endif
762 10 return ok;
763 10 }
764
765 18 void emb_write_torch(const torch::Tensor& keys, const torch::Tensor& values) {
766
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 TORCH_CHECK(keys.dim() == 1, "Keys tensor must be 1-dimensional");
767
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 TORCH_CHECK(keys.scalar_type() == torch::kInt64,
768 "Keys tensor must have dtype int64");
769
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 TORCH_CHECK(keys.is_contiguous(), "Keys tensor must be contiguous");
770
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 TORCH_CHECK(values.dim() == 2, "Values tensor must be 2-dimensional");
771
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 TORCH_CHECK(values.scalar_type() == torch::kFloat32,
772 "Values tensor must have dtype float32");
773
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 TORCH_CHECK(values.is_contiguous(), "Values tensor must be contiguous");
774
3/6
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
18 TORCH_CHECK(keys.size(0) == values.size(0),
775 "Keys and Values tensors must have the same number of entries");
776
777
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 if (keys.size(0) == 0) {
778 return;
779 }
780
781
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 auto op = GetKVClientOp();
782
783
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 torch::Tensor cpu_keys = keys;
784
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 torch::Tensor cpu_values = values;
785
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 if (keys.is_cuda()) {
786 cpu_keys = keys.cpu();
787 }
788
2/4
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 if (values.is_cuda()) {
789 cpu_values = values.cpu();
790 }
791
792
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 base::RecTensor rec_keys = ToRecTensor(cpu_keys, base::DataType::UINT64);
793
1/2
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
18 base::RecTensor rec_values = ToRecTensor(cpu_values, base::DataType::FLOAT32);
794
795
1/2
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
18 op->EmbWrite(rec_keys, rec_values);
796 #ifdef RECSTORE_ENABLE_GPU_CACHE
797 if (gpu::IsGpuCacheEnabled()) {
798 SafeClearGpuCacheNoThrow();
799 gpu::ResetLastGpuCacheProfile();
800 }
801 #endif
802 18 }
803
804 void set_ps_config_torch(const std::string& host, int64_t port) {
805 auto kv_op = GetConcreteKVClientOp();
806 kv_op->SetPSConfig(host, static_cast<int>(port));
807 }
808
809 void set_ps_backend_torch(const std::string& backend) {
810 auto kv_op = GetConcreteKVClientOp();
811 kv_op->SetPSBackend(backend);
812 }
813
814 std::string current_ps_backend_torch() {
815 auto kv_op = GetConcreteKVClientOp();
816 return kv_op->CurrentPSBackend();
817 }
818
// Turns on the GPU embedding cache with the given capacity and row width.
// Returns the result of gpu::EnableGpuCache. On success, any lookup-bypass
// state is reset. Without RECSTORE_ENABLE_GPU_CACHE this is a no-op that
// returns false.
bool enable_gpu_cache_torch(int64_t capacity, int64_t embedding_dim) {
#if defined(RECSTORE_ENABLE_GPU_CACHE)
  if (!gpu::EnableGpuCache(capacity, embedding_dim)) {
    return false;
  }
  ResetGpuCacheBypassState();
  return true;
#else
  static_cast<void>(capacity);
  static_cast<void>(embedding_dim);
  return false;
#endif
}
832
// Turns off the GPU embedding cache and resets lookup-bypass state.
// Does nothing when built without RECSTORE_ENABLE_GPU_CACHE.
void disable_gpu_cache_torch() {
#if defined(RECSTORE_ENABLE_GPU_CACHE)
  gpu::DisableGpuCache();
  ResetGpuCacheBypassState();
#endif
}
839
// Empties the GPU embedding cache (via gpu::ClearGpuCache, which may throw,
// unlike the SafeClear variant used elsewhere) and resets lookup-bypass
// state. Does nothing when built without RECSTORE_ENABLE_GPU_CACHE.
void clear_gpu_cache_torch() {
#if defined(RECSTORE_ENABLE_GPU_CACHE)
  gpu::ClearGpuCache();
  ResetGpuCacheBypassState();
#endif
}
846
847 void prefill_gpu_cache_torch(const torch::Tensor& keys,
848 const torch::Tensor& values) {
849 #ifdef RECSTORE_ENABLE_GPU_CACHE
850 TORCH_CHECK(keys.dim() == 1, "keys must be 1-dimensional");
851 TORCH_CHECK(keys.scalar_type() == torch::kInt64,
852 "keys must have dtype int64");
853 TORCH_CHECK(values.dim() == 2, "values must be 2-dimensional");
854 TORCH_CHECK(values.scalar_type() == torch::kFloat32,
855 "values must have dtype float32");
856 TORCH_CHECK(keys.size(0) == values.size(0),
857 "keys and values must have the same number of rows");
858 if (keys.numel() == 0) {
859 return;
860 }
861 TORCH_CHECK(keys.is_cuda() || values.is_cuda(),
862 "prefill_gpu_cache requires keys or values on CUDA");
863 const auto cache_device = values.is_cuda() ? values.device() : keys.device();
864 auto keys_cuda = keys.is_cuda() ? keys : keys.to(cache_device);
865 auto values_cuda = values.is_cuda() ? values : values.to(cache_device);
866 if (!keys_cuda.is_contiguous()) {
867 keys_cuda = keys_cuda.contiguous();
868 }
869 if (!values_cuda.is_contiguous()) {
870 values_cuda = values_cuda.contiguous();
871 }
872 gpu::FillGpuCache(keys_cuda, values_cuda);
873 #else
874 (void)keys;
875 (void)values;
876 #endif
877 }
878
// Switches the GPU-cache lookup-bypass feature on or off by delegating to
// SetGpuCacheLookupBypassEnabled. The flag is ignored when built without
// RECSTORE_ENABLE_GPU_CACHE.
void set_gpu_cache_lookup_bypass_enabled_torch(bool enabled) {
#if defined(RECSTORE_ENABLE_GPU_CACHE)
  SetGpuCacheLookupBypassEnabled(enabled);
#else
  static_cast<void>(enabled);
#endif
}
886
// Reports the current value of the lookup-bypass "enabled" flag.
// Always false when built without RECSTORE_ENABLE_GPU_CACHE.
bool is_gpu_cache_lookup_bypass_enabled_torch() {
#if defined(RECSTORE_ENABLE_GPU_CACHE)
  const bool enabled = g_gpu_cache_lookup_bypass_enabled;
  return enabled;
#else
  return false;
#endif
}
894
// Reports the current value of the "lookup is being bypassed" flag.
// Always false when built without RECSTORE_ENABLE_GPU_CACHE.
bool is_gpu_cache_lookup_bypassed_torch() {
#if defined(RECSTORE_ENABLE_GPU_CACHE)
  const bool bypassed = g_gpu_cache_lookup_bypassed;
  return bypassed;
#else
  return false;
#endif
}
902
// Exposes ResetGpuCacheBypassState to callers of the op library.
// Does nothing when built without RECSTORE_ENABLE_GPU_CACHE.
void reset_gpu_cache_bypass_state_torch() {
#if defined(RECSTORE_ENABLE_GPU_CACHE)
  ResetGpuCacheBypassState();
#endif
}
908
// Returns the last GPU-cache profile snapshot as a flat list of doubles.
// Callers rely on this fixed positional layout:
//   [0] query_ms       [1] backend_lookup_ms  [2] fill_ms
//   [3] update_ms      [4] hit_count          [5] invalidate_ms
//   [6] request_count  [7] miss_count
// Returns an empty list when built without RECSTORE_ENABLE_GPU_CACHE.
std::vector<double> get_last_gpu_cache_profile_torch() {
#if !defined(RECSTORE_ENABLE_GPU_CACHE)
  return {};
#else
  const auto snapshot = gpu::GetLastGpuCacheProfile();
  std::vector<double> flat;
  flat.reserve(8);
  flat.push_back(static_cast<double>(snapshot.query_ms));
  flat.push_back(static_cast<double>(snapshot.backend_lookup_ms));
  flat.push_back(static_cast<double>(snapshot.fill_ms));
  flat.push_back(static_cast<double>(snapshot.update_ms));
  flat.push_back(static_cast<double>(snapshot.hit_count));
  flat.push_back(static_cast<double>(snapshot.invalidate_ms));
  flat.push_back(static_cast<double>(snapshot.request_count));
  flat.push_back(static_cast<double>(snapshot.miss_count));
  return flat;
#endif
}
926
927 4 TORCH_LIBRARY(recstore_ops, m) {
928
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("emb_read", emb_read_torch);
929
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("local_lookup_flat", local_lookup_flat_torch);
930
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("emb_update", emb_update_torch);
931
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("emb_update_table", emb_update_table_torch);
932
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("local_update_flat", local_update_flat_torch);
933
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("init_embedding_table", init_embedding_table_torch);
934
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("emb_write", emb_write_torch);
935
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("emb_prefetch", emb_prefetch_torch);
936
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("emb_wait_result", emb_wait_result_torch);
937
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("set_ps_config", set_ps_config_torch);
938
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("set_ps_backend", set_ps_backend_torch);
939
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("current_ps_backend", current_ps_backend_torch);
940
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("get_last_local_lookup_flat_profile",
941 get_last_local_lookup_flat_profile_torch);
942
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("get_last_local_update_flat_profile",
943 get_last_local_update_flat_profile_torch);
944
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("warmup_local_lookup_flat_cuda_region",
945 warmup_local_lookup_flat_cuda_region_torch);
946
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("enable_gpu_cache", enable_gpu_cache_torch);
947
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("disable_gpu_cache", disable_gpu_cache_torch);
948
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("clear_gpu_cache", clear_gpu_cache_torch);
949
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("prefill_gpu_cache", prefill_gpu_cache_torch);
950
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("set_gpu_cache_lookup_bypass_enabled",
951 set_gpu_cache_lookup_bypass_enabled_torch);
952
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("is_gpu_cache_lookup_bypass_enabled",
953 is_gpu_cache_lookup_bypass_enabled_torch);
954
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("is_gpu_cache_lookup_bypassed", is_gpu_cache_lookup_bypassed_torch);
955
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("reset_gpu_cache_bypass_state", reset_gpu_cache_bypass_state_torch);
956
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 m.def("get_last_gpu_cache_profile", get_last_gpu_cache_profile_torch);
957 4 }
958
959 } // namespace framework
960 } // namespace recstore
961