GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 48 / 0 / 48
Functions: 100.0% 6 / 0 / 6
Branches: 50.0% 14 / 0 / 28

ps/grpc/grpc_ps_client.h
Line Branch Exec Source
1 #pragma once
2
3 #include <cstdint>
4 #include <future>
5 #include <mutex>
6 #include <string>
7 #include <unordered_map>
8 #include <vector>
9
10 #include "base/array.h"
11 #include "base/flatc.h"
12 #include "base/init.h"
13 #include "base/json.h"
14 #include "ps/base/base_client.h"
15 #include "ps/base/parameters.h"
16 #include "ps.grpc.pb.h"
17 #include "ps.pb.h"
18 #include "base/tensor.h"
19
20 using grpc::Channel;
21 using grpc::ClientContext;
22 using grpc::Status;
23 using recstoreps::CommandRequest;
24 using recstoreps::CommandResponse;
25 using recstoreps::GetParameterRequest;
26 using recstoreps::GetParameterResponse;
27 using recstoreps::PSCommand;
28 using recstoreps::PutParameterRequest;
29 using recstoreps::PutParameterResponse;
30
31 using base::ConstArray;
32 using json = nlohmann::json;
33
34 static const int MAX_PARAMETER_BATCH = 2000;
35
36 struct PrefetchBatch {
37 112 PrefetchBatch(int request_num) {
38 112 batch_size_ = request_num;
39
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 key_sizes_.resize(request_num);
40
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 status_.resize(request_num);
41
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 contexts_.resize(request_num);
42
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 requests_.resize(request_num);
43
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 responses_.resize(request_num);
44
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 response_readers_.resize(request_num);
45
1/2
✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
112 cqs_ = std::make_unique<grpc::CompletionQueue>();
46 112 completed_count_ = 0;
47 112 }
48
49 112 PrefetchBatch(PrefetchBatch&& other) noexcept
50 112 : key_sizes_(std::move(other.key_sizes_)),
51 112 status_(std::move(other.status_)),
52 112 contexts_(std::move(other.contexts_)),
53 112 requests_(std::move(other.requests_)),
54 112 responses_(std::move(other.responses_)),
55 112 response_readers_(std::move(other.response_readers_)),
56 112 batch_size_(other.batch_size_),
57 112 cqs_(std::move(other.cqs_)),
58 112 completed_count_(other.completed_count_) {
59 112 other.batch_size_ = 0;
60 112 }
61 PrefetchBatch(const PrefetchBatch&) = delete;
62 PrefetchBatch& operator=(const PrefetchBatch&) = delete;
63
64 std::vector<int> key_sizes_;
65 std::vector<Status> status_;
66 std::vector<std::unique_ptr<ClientContext>> contexts_;
67 std::vector<GetParameterRequest> requests_;
68 std::vector<GetParameterResponse> responses_;
69 std::vector<
70 std::unique_ptr<grpc::ClientAsyncResponseReader<GetParameterResponse>>>
71 response_readers_;
72
73 int batch_size_;
74 int completed_count_;
75 std::unique_ptr<grpc::CompletionQueue> cqs_;
76 };
77
78 struct PrewriteBatch {
79 8 PrewriteBatch(int request_num) {
80 8 batch_size_ = request_num;
81
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 key_sizes_.resize(request_num);
82
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 status_.resize(request_num);
83
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 contexts_.resize(request_num);
84
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 requests_.resize(request_num);
85
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 responses_.resize(request_num);
86
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 response_readers_.resize(request_num);
87
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 cqs_ = std::make_unique<grpc::CompletionQueue>();
88 8 completed_count_ = 0;
89 8 }
90
91 8 PrewriteBatch(PrewriteBatch&& other) noexcept
92 8 : key_sizes_(std::move(other.key_sizes_)),
93 8 status_(std::move(other.status_)),
94 8 contexts_(std::move(other.contexts_)),
95 8 requests_(std::move(other.requests_)),
96 8 responses_(std::move(other.responses_)),
97 8 response_readers_(std::move(other.response_readers_)),
98 8 batch_size_(other.batch_size_),
99 8 cqs_(std::move(other.cqs_)),
100 8 completed_count_(other.completed_count_) {
101 8 other.batch_size_ = 0;
102 8 }
103 PrewriteBatch(const PrewriteBatch&) = delete;
104 PrewriteBatch& operator=(const PrewriteBatch&) = delete;
105
106 std::vector<int> key_sizes_;
107 std::vector<Status> status_;
108 std::vector<std::unique_ptr<ClientContext>> contexts_;
109 std::vector<PutParameterRequest> requests_;
110 std::vector<PutParameterResponse> responses_;
111 std::vector<
112 std::unique_ptr<grpc::ClientAsyncResponseReader<PutParameterResponse>>>
113 response_readers_;
114
115 int batch_size_;
116 int completed_count_;
117 std::unique_ptr<grpc::CompletionQueue> cqs_;
118 };
119
120 class GRPCParameterClient : public recstore::BasePSClient {
121 public:
122 // New constructor with JSON config
123 explicit GRPCParameterClient(json config);
124
125 // Legacy constructor for backward compatibility
126 explicit GRPCParameterClient(const std::string& host, int port, int shard);
127
128 78 ~GRPCParameterClient() {}
129
130 // BasePSClient pure virtual implementations
131 virtual int
132 GetParameter(const base::ConstArray<uint64_t>& keys, float* values) override;
133
134 int AsyncGetParameter(const base::ConstArray<uint64_t>& keys,
135 float* values) override;
136
137 int PutParameter(const base::ConstArray<uint64_t>& keys,
138 const std::vector<std::vector<float>>& values) override;
139
140 void Command(recstore::PSCommand command) override;
141
142 // Legacy API methods
143 int GetParameter(const base::ConstArray<uint64_t>& keys,
144 std::vector<std::vector<float>>* values);
145 bool GetParameter(const base::ConstArray<unsigned int>& keys,
146 std::vector<std::vector<float>>* values);
147
148 inline int shard() const { return shard_; }
149
150 bool ClearPS();
151
152 bool LoadFakeData(int64_t data);
153
154 // Write n bytes of random floats to storage at key 0. n must be a positive
155 // multiple of sizeof(float).
156 bool DumpFakeData(int64_t n);
157
158 bool LoadCkpt(const std::vector<std::string>& model_config_path,
159 const std::vector<std::string>& emb_file_path);
160
161 bool PutParameter(const std::vector<uint64_t>& keys,
162 const std::vector<std::vector<float>>& values);
163
164 int UpdateParameter(const std::string& table_name,
165 const base::ConstArray<uint64_t>& keys,
166 const std::vector<std::vector<float>>* grads);
167 int UpdateParameterFlat(const std::string& table_name,
168 const base::ConstArray<uint64_t>& keys,
169 const float* grads,
170 int64_t num_rows,
171 int64_t embedding_dim) override;
172
173 int InitEmbeddingTable(const std::string& table_name,
174 const recstore::EmbeddingTableConfig& config);
175
176 uint64_t PrefetchParameter(const base::ConstArray<uint64_t>& keys);
177 bool IsPrefetchDone(uint64_t prefetch_id);
178 void WaitForPrefetch(uint64_t prefetch_id);
179 bool GetPrefetchResult(uint64_t prefetch_id,
180 std::vector<std::vector<float>>* values);
181 bool GetPrefetchResultFlat(uint64_t prefetch_id,
182 std::vector<float>* values,
183 int64_t* num_rows,
184 int64_t embedding_dim) override;
185 // Embeddings are vectors here; Get(float*) uses a flat buffer for legacy
186 // callers
187 virtual uint64_t EmbWriteAsync(const base::ConstArray<uint64_t>& keys,
188 const std::vector<std::vector<float>>& values);
189 virtual bool IsWriteDone(uint64_t write_id);
190 virtual void WaitForWrite(uint64_t write_id);
191
192 protected:
193 40 bool Initialize() { return true; }
194 std::string host_;
195 int port_;
196 int shard_;
197 int nr_clients_;
198 std::vector<float> cache_;
199 std::vector<int32_t> offset_;
200 std::vector<int> get_param_key_sizes_;
201 std::vector<Status> get_param_status_;
202 std::vector<GetParameterRequest> get_param_requests_;
203 std::vector<GetParameterResponse> get_param_responses_;
204 std::vector<std::unique_ptr<grpc::ClientContext>> get_param_contexts_;
205 std::vector<
206 std::unique_ptr<grpc::ClientAsyncResponseReader<GetParameterResponse>>>
207 get_param_resonse_readers_;
208 std::shared_ptr<Channel> channel_;
209 std::vector<std::unique_ptr<recstoreps::ParameterService::Stub>> stubs_;
210 std::unique_ptr<grpc::CompletionQueue> cq_;
211
212 private:
213 std::mutex prefetch_mu_;
214 std::unordered_map<uint64_t, struct PrefetchBatch> prefetch_batches_;
215 std::unordered_map<uint64_t, struct PrewriteBatch> prewrite_batches_;
216 // start from 1
217 uint64_t next_prefetch_id_ = 1;
218 uint64_t next_prewrite_id_ = 1;
219 };
220