GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 1 / 0 / 1
Functions: 100.0% 1 / 0 / 1
Branches: -% 0 / 0 / 0

ps/brpc/dist_brpc_ps_client.h
Line Branch Exec Source
1 #pragma once
2
3 #include <future>
4 #include <memory>
5 #include <mutex>
6 #include <string>
7 #include <unordered_map>
8 #include <vector>
9
10 #include "base/array.h"
11 #include "base/hash.h"
12 #include "base/json.h"
13 #include "base/log.h"
14 #include "ps/base/base_client.h"
15 #include "brpc_ps_client.h"
16
17 using json = nlohmann::json;
18
19 namespace recstore {
20
21 /**
22 * @brief Distributed bRPC parameter-server client
23 *
24 * Many-to-many connections; routes keys to shards via a hash function.
25 * Server list and hash method come from JSON config.
26 */
27 class DistributedBRPCParameterClient : public BasePSClient {
28 public:
29 explicit DistributedBRPCParameterClient(json config);  // `config` is taken by value; construction logic is defined outside this header
30
31 ~DistributedBRPCParameterClient();
32
33 // BasePSClient pure virtual implementations (every declaration in this group is marked `override`)
34 int GetParameter(const base::ConstArray<uint64_t>& keys,
35 float* values) override;  // NOTE(review): required size/layout of `values` is not visible in this header — confirm in the .cpp
36
37 int AsyncGetParameter(const base::ConstArray<uint64_t>& keys,
38 float* values) override;  // NOTE(review): signature exposes no completion handle; how callers wait is defined elsewhere
39
40 int PutParameter(const base::ConstArray<uint64_t>& keys,
41 const std::vector<std::vector<float>>& values) override;
42
43 void Command(PSCommand command) override;
44
45 int UpdateParameter(const std::string& table_name,
46 const base::ConstArray<uint64_t>& keys,
47 const std::vector<std::vector<float>>* grads) override;  // `grads` is a pointer; null handling is not visible here
48 int UpdateParameterFlat(const std::string& table_name,
49 const base::ConstArray<uint64_t>& keys,
50 const float* grads,
51 int64_t num_rows,
52 int64_t embedding_dim) override;  // flat-pointer variant of UpdateParameter with explicit row count and embedding dimension
53
54 int InitEmbeddingTable(const std::string& table_name,
55 const recstore::EmbeddingTableConfig& config) override;
56
57 // Prefetch API (stubbed for distributed client). NOTE(review): the private section declares prefetch bookkeeping (DistPrefetchState, prefetch_states_, next_prefetch_id_) — confirm whether "stubbed" is still accurate.
58 uint64_t PrefetchParameter(const base::ConstArray<uint64_t>& keys) override;  // returns a uint64_t id; the other prefetch calls take a `prefetch_id`
59 bool IsPrefetchDone(uint64_t prefetch_id) override;
60 void WaitForPrefetch(uint64_t prefetch_id) override;
61 bool GetPrefetchResult(uint64_t prefetch_id,
62 std::vector<std::vector<float>>* values) override;
63 bool GetPrefetchResultFlat(uint64_t prefetch_id,
64 std::vector<float>* values,
65 int64_t* num_rows,
66 int64_t embedding_dim) override;
67
68 // Extended API (declarations below are not marked `override`)
69 bool GetParameter(const base::ConstArray<uint64_t>& keys,
70 std::vector<std::vector<float>>* values);  // overload of GetParameter: returns bool and takes a vector-of-vectors pointer instead of a float*
71
72 bool ClearPS();
73
74 bool LoadFakeData(int64_t n);
75 bool DumpFakeData(int64_t n);
76 bool LoadCkpt(const std::vector<std::string>& model_config_path,
77 const std::vector<std::string>& emb_file_path);
78
79 4 int shard_count() const { return num_shards_; }  // read-only accessor for num_shards_
80
81 private:
82 struct DistPrefetchShardState {  // per-shard record held in DistPrefetchState::shard_states
83 int shard_id = -1;  // defaults to -1
84 int client_index = -1;  // defaults to -1; presumably an index into clients_ — TODO confirm
85 std::vector<size_t> original_indices;  // presumably positions of this shard's keys in the caller's key array — TODO confirm
86 std::vector<uint64_t> child_prefetch_ids;
87 std::vector<int> chunk_sizes;
88 };
89
90 struct DistPrefetchState {  // value type stored (via shared_ptr) in prefetch_states_
91 size_t total_keys = 0;
92 std::vector<DistPrefetchShardState> shard_states;
93 };
94
95 int GetShardId(uint64_t key) const;  // maps a key to a shard id; the class doc says routing uses a hash function
96
97 void InitializeClients();
98
99 void
100 PartitionKeys(const base::ConstArray<uint64_t>& keys,
101 std::vector<std::vector<uint64_t>>& partitioned_keys) const;  // const member; output goes through the non-const reference parameter
102
103 void MergeResults(
104 const base::ConstArray<uint64_t>& keys,
105 const std::vector<std::vector<std::vector<float>>>& partitioned_results,
106 std::vector<std::vector<float>>* values) const;  // vector-of-vectors output variant
107
108 void MergeResultsToArray(
109 const base::ConstArray<uint64_t>& keys,
110 const std::vector<std::vector<std::vector<float>>>& partitioned_results,
111 float* values) const;  // float* output variant
112
113 private:
114 // Config. NOTE(review): num_shards_ and max_keys_per_request_ have no in-class initializer — confirm the constructor sets both.
115 int num_shards_;
116 int max_keys_per_request_;
117 std::string hash_method_;
118
119 // Per-server entries
120 struct ServerConfig {
121 std::string host;
122 int port;  // no in-class initializer
123 int shard;  // no in-class initializer
124 };
125 std::vector<ServerConfig> server_configs_;
126
127 // bRPC clients (one per server entry)
128 std::vector<std::unique_ptr<BRPCParameterClient>> clients_;
129
130 // Logical shard id -> index in clients_
131 std::unordered_map<int, int> shard_to_client_index_;
132
133 // Partition buffers (reused). Declared `mutable`, so const member functions may modify them. NOTE(review): the only mutex declared in this class is prefetch_mu_ — confirm these buffers are never used concurrently.
134 mutable std::vector<std::vector<uint64_t>> partitioned_key_buffer_;
135 mutable std::vector<std::vector<size_t>> key_index_mapping_;
136
137 std::mutex prefetch_mu_;  // presumably guards prefetch_states_ and next_prefetch_id_ — TODO confirm in the .cpp
138 std::unordered_map<uint64_t, std::shared_ptr<DistPrefetchState>>
139 prefetch_states_;  // keyed by uint64_t; presumably the id returned by PrefetchParameter — TODO confirm
140 uint64_t next_prefetch_id_ = 1;  // starts at 1
141 };
142
143 } // namespace recstore
144