GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 1 / 0 / 1
Functions: 100.0% 1 / 0 / 1
Branches: -% 0 / 0 / 0

ps/grpc/dist_grpc_ps_client.h
Line Branch Exec Source
1 #pragma once
2
3 #include <future>
4 #include <memory>
5 #include <mutex>
6 #include <string>
7 #include <unordered_map>
8 #include <vector>
9
10 #include "base/array.h"
11 #include "base/hash.h"
12 #include "base/json.h"
13 #include "base/log.h"
14 #include "ps/base/base_client.h"
15 #include "grpc_ps_client.h"
16
17 using json = nlohmann::json;
18
19 namespace recstore {
20
21 /**
22 * @brief Distributed gRPC parameter-server client
23 *
24 * Many-to-many connections; routes keys to shards via a hash function.
25 * Server list and hash method come from JSON config.
26 */
27 class DistributedGRPCParameterClient : public BasePSClient {
28 public:
29 explicit DistributedGRPCParameterClient(json config);
30
31 ~DistributedGRPCParameterClient();
32
33 // BasePSClient pure virtual implementations
34 int GetParameter(const base::ConstArray<uint64_t>& keys,
35 float* values) override;
36
37 int AsyncGetParameter(const base::ConstArray<uint64_t>& keys,
38 float* values) override;
39
40 int PutParameter(const base::ConstArray<uint64_t>& keys,
41 const std::vector<std::vector<float>>& values) override;
42
43 void Command(PSCommand command) override;
44
45 // Prefetch API (stubbed for distributed client)
46 uint64_t PrefetchParameter(const base::ConstArray<uint64_t>& keys) override;
47 bool IsPrefetchDone(uint64_t prefetch_id) override;
48 void WaitForPrefetch(uint64_t prefetch_id) override;
49 bool GetPrefetchResult(uint64_t prefetch_id,
50 std::vector<std::vector<float>>* values) override;
51 bool GetPrefetchResultFlat(uint64_t prefetch_id,
52 std::vector<float>* values,
53 int64_t* num_rows,
54 int64_t embedding_dim) override;
55
56 int UpdateParameter(const std::string& table_name,
57 const base::ConstArray<uint64_t>& keys,
58 const std::vector<std::vector<float>>* grads) override;
59 int UpdateParameterFlat(const std::string& table_name,
60 const base::ConstArray<uint64_t>& keys,
61 const float* grads,
62 int64_t num_rows,
63 int64_t embedding_dim) override;
64
65 int InitEmbeddingTable(const std::string& table_name,
66 const recstore::EmbeddingTableConfig& config) override;
67
68 // Extended API
69 bool GetParameter(const base::ConstArray<uint64_t>& keys,
70 std::vector<std::vector<float>>* values);
71
72 bool ClearPS();
73
74 // Broadcast to every shard (same as underlying single-shard semantics).
75 bool LoadFakeData(int64_t n);
76 bool DumpFakeData(int64_t n);
77
78 bool LoadCkpt(const std::vector<std::string>& model_config_path,
79 const std::vector<std::string>& emb_file_path);
80
81 4 int shard_count() const { return num_shards_; }
82
83 private:
84 struct DistPrefetchShardState {
85 int shard_id = -1;
86 int client_index = -1;
87 std::vector<size_t> original_indices;
88 std::vector<uint64_t> child_prefetch_ids;
89 std::vector<int> chunk_sizes;
90 };
91
92 struct DistPrefetchState {
93 size_t total_keys = 0;
94 std::vector<DistPrefetchShardState> shard_states;
95 };
96
97 int GetShardId(uint64_t key) const;
98
99 void InitializeClients();
100
101 void
102 PartitionKeys(const base::ConstArray<uint64_t>& keys,
103 std::vector<std::vector<uint64_t>>& partitioned_keys) const;
104
105 void MergeResults(
106 const base::ConstArray<uint64_t>& keys,
107 const std::vector<std::vector<std::vector<float>>>& partitioned_results,
108 std::vector<std::vector<float>>* values) const;
109
110 void MergeResultsToArray(
111 const base::ConstArray<uint64_t>& keys,
112 const std::vector<std::vector<std::vector<float>>>& partitioned_results,
113 float* values) const;
114
115 private:
116 // Config
117 int num_shards_;
118 int max_keys_per_request_;
119 std::string hash_method_;
120
121 // Per-server entries
122 struct ServerConfig {
123 std::string host;
124 int port;
125 int shard;
126 };
127 std::vector<ServerConfig> server_configs_;
128
129 // gRPC clients (one per server entry)
130 std::vector<std::unique_ptr<GRPCParameterClient>> clients_;
131
132 // Logical shard id -> index in clients_
133 std::unordered_map<int, int> shard_to_client_index_;
134
135 // Partition buffers (reused)
136 mutable std::vector<std::vector<uint64_t>> partitioned_key_buffer_;
137 mutable std::vector<std::vector<size_t>> key_index_mapping_;
138
139 std::mutex prefetch_mu_;
140 std::unordered_map<uint64_t, std::shared_ptr<DistPrefetchState>>
141 prefetch_states_;
142 uint64_t next_prefetch_id_ = 1;
143 };
144
145 } // namespace recstore
146