GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 16 / 0 / 16
Functions: 100.0% 3 / 0 / 3
Branches: 50.0% 7 / 0 / 14

ps/brpc/brpc_ps_client.h
Line Branch Exec Source
1 #pragma once
2
3 #include <brpc/channel.h>
4 #include <brpc/controller.h>
5 #include <butil/logging.h>
6
7 #include <cstdint>
8 #include <memory>
9 #include <string>
10 #include <unordered_map>
11 #include <vector>
12 #include <atomic>
13
14 #include "base/array.h"
15 #include "base/flatc.h"
16 #include "base/json.h"
17 #include "base/tensor.h"
18 #include "ps/base/base_client.h"
19 #include "ps/base/parameters.h"
20 #include "ps_brpc.pb.h"
21
22 using json = nlohmann::json;
23
24 static const int MAX_PARAMETER_BATCH_BRPC = 2000;
25
26 // Prefetch batch structure for bRPC
27 struct BrpcPrefetchBatch {
28 10 BrpcPrefetchBatch(int request_num) {
29 10 batch_size_ = request_num;
30
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 key_sizes_.resize(request_num);
31
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 responses_.resize(request_num);
32
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 controllers_.resize(request_num);
33 10 completed_count_ = 0;
34 10 }
35
36 BrpcPrefetchBatch(BrpcPrefetchBatch&& other) noexcept
37 : key_sizes_(std::move(other.key_sizes_)),
38 responses_(std::move(other.responses_)),
39 controllers_(std::move(other.controllers_)),
40 batch_size_(other.batch_size_),
41 completed_count_(other.completed_count_.load()) {
42 other.batch_size_ = 0;
43 }
44
45 BrpcPrefetchBatch(const BrpcPrefetchBatch&) = delete;
46 BrpcPrefetchBatch& operator=(const BrpcPrefetchBatch&) = delete;
47
48 std::vector<int> key_sizes_;
49 std::vector<recstoreps_brpc::GetParameterResponse> responses_;
50 std::vector<std::unique_ptr<brpc::Controller>> controllers_;
51 int batch_size_;
52 std::atomic<int> completed_count_;
53 };
54
55 struct BrpcPrewriteBatch {
56 8 BrpcPrewriteBatch(int request_num) {
57 8 batch_size_ = request_num;
58
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 key_sizes_.resize(request_num);
59
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 requests_.resize(request_num);
60
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 responses_.resize(request_num);
61
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 controllers_.resize(request_num);
62 8 completed_count_ = 0;
63 8 }
64
65 BrpcPrewriteBatch(BrpcPrewriteBatch&& other) noexcept
66 : key_sizes_(std::move(other.key_sizes_)),
67 requests_(std::move(other.requests_)),
68 responses_(std::move(other.responses_)),
69 controllers_(std::move(other.controllers_)),
70 batch_size_(other.batch_size_),
71 completed_count_(other.completed_count_.load()) {
72 other.batch_size_ = 0;
73 }
74
75 BrpcPrewriteBatch(const BrpcPrewriteBatch&) = delete;
76 BrpcPrewriteBatch& operator=(const BrpcPrewriteBatch&) = delete;
77
78 std::vector<int> key_sizes_;
79 std::vector<recstoreps_brpc::PutParameterRequest> requests_;
80 std::vector<recstoreps_brpc::PutParameterResponse> responses_;
81 std::vector<std::unique_ptr<brpc::Controller>> controllers_;
82 int batch_size_;
83 std::atomic<int> completed_count_;
84 };
85
86 class BRPCParameterClient : public recstore::BasePSClient {
87 public:
88 // New constructor with JSON config
89 explicit BRPCParameterClient(json config);
90
91 // Legacy constructor for backward compatibility
92 explicit BRPCParameterClient(const std::string& host, int port, int shard);
93
94 54 ~BRPCParameterClient() {}
95
96 // BasePSClient pure virtual implementations
97 virtual int
98 GetParameter(const base::ConstArray<uint64_t>& keys, float* values) override;
99
100 int AsyncGetParameter(const base::ConstArray<uint64_t>& keys,
101 float* values) override;
102
103 int PutParameter(const base::ConstArray<uint64_t>& keys,
104 const std::vector<std::vector<float>>& values) override;
105
106 void Command(recstore::PSCommand command) override;
107
108 // Legacy API methods
109 int GetParameter(const base::ConstArray<uint64_t>& keys,
110 std::vector<std::vector<float>>* values);
111
112 inline int shard() const { return shard_; }
113
114 bool ClearPS();
115
116 bool LoadFakeData(int64_t data);
117
118 bool DumpFakeData(int64_t n);
119
120 bool LoadCkpt(const std::vector<std::string>& model_config_path,
121 const std::vector<std::string>& emb_file_path);
122
123 bool PutParameter(const std::vector<uint64_t>& keys,
124 const std::vector<std::vector<float>>& values);
125
126 int UpdateParameter(const std::string& table_name,
127 const base::ConstArray<uint64_t>& keys,
128 const std::vector<std::vector<float>>* grads);
129 int UpdateParameterFlat(const std::string& table_name,
130 const base::ConstArray<uint64_t>& keys,
131 const float* grads,
132 int64_t num_rows,
133 int64_t embedding_dim) override;
134
135 int InitEmbeddingTable(const std::string& table_name,
136 const recstore::EmbeddingTableConfig& config);
137
138 // Prefetch API
139 uint64_t PrefetchParameter(const base::ConstArray<uint64_t>& keys);
140 bool IsPrefetchDone(uint64_t prefetch_id);
141 void WaitForPrefetch(uint64_t prefetch_id);
142 bool GetPrefetchResult(uint64_t prefetch_id,
143 std::vector<std::vector<float>>* values);
144 bool GetPrefetchResultFlat(uint64_t prefetch_id,
145 std::vector<float>* values,
146 int64_t* num_rows,
147 int64_t embedding_dim) override;
148
149 virtual uint64_t
150 EmbWriteAsync(const base::RecTensor& keys, const base::RecTensor& values);
151 virtual bool IsWriteDone(uint64_t write_id);
152 virtual void WaitForWrite(uint64_t write_id);
153
154 protected:
155 bool Initialize();
156
157 std::string host_;
158 int port_;
159 int shard_;
160 int timeout_ms_;
161 int max_retry_;
162
163 // bRPC channel
164 std::shared_ptr<brpc::Channel> channel_;
165
166 std::vector<float> cache_;
167 std::vector<int32_t> offset_;
168
169 private:
170 std::unordered_map<uint64_t, struct BrpcPrefetchBatch> prefetch_batches_;
171 std::unordered_map<uint64_t, struct BrpcPrewriteBatch> prewrite_batches_;
172 uint64_t next_prefetch_id_ = 1;
173 uint64_t next_prewrite_id_ = 1;
174 };
175