GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 16 / 0 / 16
Functions: 100.0% 7 / 0 / 7
Branches: -% 0 / 0 / 0

framework/op.h
Line Branch Exec Source
1 #pragma once
2
3 #include "base/tensor.h"
4 #include "framework/common/op_runtime_support.h"
5 #include "ps/base/base_client.h"
6 #include <cstddef>
7 #include <memory>
8 #include <string>
9 #include <mutex>
10 #include <unordered_map>
11 #include <vector>
12
13 using base::RecTensor;
14
15 namespace recstore {
16 enum class InitStrategyType { Normal, Uniform, Xavier, Zero };
17 struct LocalShmFlatGetHandle;
18
19 struct InitStrategy {
20 InitStrategy() = delete;
21 InitStrategyType type;
22
23 // Optional fields depending on type
24 float mean = 0.0f;
25 float std = 1.0f;
26 float lower = -1.0f;
27 float upper = 1.0f;
28
29 12 InitStrategy(InitStrategyType t) : type(t) {}
30
31 2 static InitStrategy Normal(float mean, float std) {
32 2 InitStrategy s(InitStrategyType::Normal);
33 2 s.mean = mean;
34 2 s.std = std;
35 2 return s;
36 }
37
38 2 static InitStrategy Uniform(float lower, float upper) {
39 2 InitStrategy s(InitStrategyType::Uniform);
40 2 s.lower = lower;
41 2 s.upper = upper;
42 2 return s;
43 }
44
45 2 static InitStrategy Xavier() {
46 2 return InitStrategy(InitStrategyType::Xavier);
47 }
48 6 static InitStrategy Zero() { return InitStrategy(InitStrategyType::Zero); }
49 };
50 class CommonOp {
51 public:
52 // keys: uint64_t tensor with shape [N]
53 // values: emb.dtype tensor with shape [N, D]
54
55 34 CommonOp() = default;
56
57 virtual void EmbInit(const RecTensor& keys, const RecTensor& init_values) = 0;
58 virtual void EmbInit(const RecTensor& keys, const InitStrategy& strategy) = 0;
59
60 // Core KV APIs (sync)
61 virtual void
62 EmbRead(const RecTensor& keys, RecTensor& values) = 0; // sync read
63 virtual void
64 EmbWrite(const RecTensor& keys, const RecTensor& values) = 0; // sync write
65
66 virtual bool
67 EmbExists(const RecTensor& keys) = 0; // not urgent, optional existence check
68 virtual void
69 EmbDelete(const RecTensor& keys) = 0; // not urgent, optional deletion
70
71 // Optional Gradient Hook (can be omitted if optimizer is outside)
72 virtual void
73 EmbUpdate(const RecTensor& keys, const RecTensor& grads) = 0; // not urgent
74 virtual void EmbUpdate(const std::string& table_name,
75 const RecTensor& keys,
76 const RecTensor& grads) = 0;
77
78 virtual bool InitEmbeddingTable(const std::string& table_name,
79 const EmbeddingTableConfig& config) = 0;
80
81 // Prefetch & write (async)
82 virtual uint64_t
83 EmbPrefetch(const RecTensor& keys,
84 const RecTensor& values) = 0; // async prefetch, returns a unique
85 // ID to track the prefetch status.
86 virtual bool IsPrefetchDone(
87 uint64_t prefetch_id) = 0; // returns true if the prefetch identified by
88 // prefetch_id is complete.
89 virtual void WaitForPrefetch(
90 uint64_t prefetch_id) = 0; // blocks until the prefetch identified by
91 // prefetch_id is complete.
92 virtual void GetPretchResult(uint64_t prefetch_id,
93 std::vector<std::vector<float>>* values) = 0;
94 virtual void GetPretchResultFlat(
95 uint64_t prefetch_id,
96 std::vector<float>* values,
97 int64_t* num_rows,
98 int64_t embedding_dim) = 0;
99
100 virtual uint64_t
101 EmbWriteAsync(const RecTensor& keys,
102 const RecTensor& values) = 0; // async write, returns a unique
103 // ID to track the write status.
104 virtual bool
105 IsWriteDone(uint64_t write_id) = 0; // returns true if the asynchronous write
106 // identified by write_id is complete.
107 virtual void
108 WaitForWrite(uint64_t write_id) = 0; // blocks until the asynchronous write
109 // identified by write_id is complete.
110
111 // Persistence
112 virtual void SaveToFile(const std::string& path) = 0; // not urgent
113 virtual void LoadFromFile(const std::string& path) = 0; // not urgent
114
115 34 virtual ~CommonOp() = default;
116 };
117
118 class KVClientOp : public CommonOp {
119 public:
120 KVClientOp();
121
122 void EmbInit(const base::RecTensor& keys,
123 const base::RecTensor& init_values) override;
124 void EmbInit(const base::RecTensor& keys,
125 const InitStrategy& strategy) override;
126 void EmbRead(const base::RecTensor& keys, base::RecTensor& values) override;
127 void EmbWrite(const base::RecTensor& keys,
128 const base::RecTensor& values) override;
129 void EmbUpdate(const base::RecTensor& keys,
130 const base::RecTensor& grads) override;
131 void EmbUpdate(const std::string& table_name,
132 const base::RecTensor& keys,
133 const base::RecTensor& grads) override;
134 bool InitEmbeddingTable(const std::string& table_name,
135 const EmbeddingTableConfig& config) override;
136 bool EmbExists(const base::RecTensor& keys) override;
137 void EmbDelete(const base::RecTensor& keys) override;
138 uint64_t EmbPrefetch(const base::RecTensor& keys,
139 const base::RecTensor& values) override;
140 bool IsPrefetchDone(uint64_t prefetch_id) override;
141 void WaitForPrefetch(uint64_t prefetch_id) override;
142 void GetPretchResult(uint64_t prefetch_id,
143 std::vector<std::vector<float>>* values) override;
144 void GetPretchResultFlat(uint64_t prefetch_id,
145 std::vector<float>* values,
146 int64_t* num_rows,
147 int64_t embedding_dim) override;
148 uint64_t EmbWriteAsync(const base::RecTensor& keys,
149 const base::RecTensor& values) override;
150 bool IsWriteDone(uint64_t write_id) override;
151 void WaitForWrite(uint64_t write_id) override;
152 void SaveToFile(const std::string& path) override;
153 void LoadFromFile(const std::string& path) override;
154 void SetPSConfig(const std::string& host, int port);
155 void SetPSBackend(const std::string& backend);
156 std::string CurrentPSBackend() const;
157 void LocalLookupFlat(const base::RecTensor& keys, base::RecTensor& values);
158 int SubmitLocalLookupFlat(const base::RecTensor& keys,
159 int64_t embedding_dim,
160 LocalShmFlatGetHandle* handle);
161 int WaitLocalLookupFlat(LocalShmFlatGetHandle* handle);
162 void ReleaseLocalLookupFlat(LocalShmFlatGetHandle* handle);
163 bool GetLocalLookupFlatPayloadRegion(const void** base, std::size_t* bytes);
164 void LocalUpdateFlat(const std::string& table_name,
165 const base::RecTensor& keys,
166 const base::RecTensor& grads);
167
168 private:
169 int64_t embedding_dim_;
170 std::string ps_backend_name_ = "unknown";
171 static BasePSClient* ps_client_;
172 static std::unique_ptr<BasePSClient> ps_client_holder_;
173
174 #ifdef USE_FAKE_KVCLIENT
175 std::unordered_map<uint64_t, std::vector<float>> store_;
176 std::mutex mtx_;
177 float learning_rate_;
178 std::unordered_map<uint64_t, std::vector<std::vector<float>>>
179 prefetch_results_;
180 uint64_t next_prefetch_id_ = 1;
181 #endif
182 };
183
184 std::shared_ptr<CommonOp> GetKVClientOp();
185
186 } // namespace recstore
187