ps/brpc/dist_brpc_ps_client.h
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include <future> | ||
| 4 | #include <memory> | ||
| 5 | #include <mutex> | ||
| 6 | #include <string> | ||
| 7 | #include <unordered_map> | ||
| 8 | #include <vector> | ||
| 9 | |||
| 10 | #include "base/array.h" | ||
| 11 | #include "base/hash.h" | ||
| 12 | #include "base/json.h" | ||
| 13 | #include "base/log.h" | ||
| 14 | #include "ps/base/base_client.h" | ||
| 15 | #include "brpc_ps_client.h" | ||
| 16 | |||
| 17 | using json = nlohmann::json; /* NOTE(review): declared before namespace recstore opens, so the alias json lands in the global namespace of every file that includes this header */ | ||
| 18 | |||
| 19 | namespace recstore { | ||
| 20 | |||
| 21 | /** | ||
| 22 | * @brief Distributed bRPC parameter-server client. | ||
| 23 | * | ||
| 24 | * Many-to-many connections: holds one BRPCParameterClient per configured server entry (clients_) and routes each key to a shard via a hash function (GetShardId, hash_method_). | ||
| 25 | * The server list and hash method come from the JSON config passed to the constructor. Apart from shard_count(), every member function is only declared in this header. | ||
| 26 | */ | ||
| 27 | class DistributedBRPCParameterClient : public BasePSClient { | ||
| 28 | public: | ||
| 29 | explicit DistributedBRPCParameterClient(json config); /* config is taken by value */ | ||
| 30 | |||
| 31 | ~DistributedBRPCParameterClient(); | ||
| 32 | |||
| 33 | // BasePSClient pure virtual implementations (each declaration in this group carries override) | ||
| 34 | int GetParameter(const base::ConstArray<uint64_t>& keys, | ||
| 35 | float* values) override; | ||
| 36 | |||
| 37 | int AsyncGetParameter(const base::ConstArray<uint64_t>& keys, | ||
| 38 | float* values) override; | ||
| 39 | |||
| 40 | int PutParameter(const base::ConstArray<uint64_t>& keys, | ||
| 41 | const std::vector<std::vector<float>>& values) override; | ||
| 42 | |||
| 43 | void Command(PSCommand command) override; | ||
| 44 | |||
| 45 | int UpdateParameter(const std::string& table_name, | ||
| 46 | const base::ConstArray<uint64_t>& keys, | ||
| 47 | const std::vector<std::vector<float>>* grads) override; | ||
| 48 | int UpdateParameterFlat(const std::string& table_name, | ||
| 49 | const base::ConstArray<uint64_t>& keys, | ||
| 50 | const float* grads, | ||
| 51 | int64_t num_rows, | ||
| 52 | int64_t embedding_dim) override; /* same table_name and keys as UpdateParameter, but grads is a raw float pointer plus explicit num_rows and embedding_dim */ | ||
| 53 | |||
| 54 | int InitEmbeddingTable(const std::string& table_name, | ||
| 55 | const recstore::EmbeddingTableConfig& config) override; | ||
| 56 | |||
| 57 | // Prefetch API (stubbed for distributed client). NOTE(review): DistPrefetchState and prefetch_states_ below look like real per-shard tracking; confirm that "stubbed" is still accurate. | ||
| 58 | uint64_t PrefetchParameter(const base::ConstArray<uint64_t>& keys) override; | ||
| 59 | bool IsPrefetchDone(uint64_t prefetch_id) override; | ||
| 60 | void WaitForPrefetch(uint64_t prefetch_id) override; | ||
| 61 | bool GetPrefetchResult(uint64_t prefetch_id, | ||
| 62 | std::vector<std::vector<float>>* values) override; | ||
| 63 | bool GetPrefetchResultFlat(uint64_t prefetch_id, | ||
| 64 | std::vector<float>* values, | ||
| 65 | int64_t* num_rows, | ||
| 66 | int64_t embedding_dim) override; /* flat counterpart of GetPrefetchResult: a single float vector, with num_rows passed as a pointer */ | ||
| 67 | |||
| 68 | // Extended API: declarations without override, specific to this client | ||
| 69 | bool GetParameter(const base::ConstArray<uint64_t>& keys, | ||
| 70 | std::vector<std::vector<float>>* values); /* overload of GetParameter that returns bool and takes a nested vector instead of a raw float pointer */ | ||
| 71 | |||
| 72 | bool ClearPS(); | ||
| 73 | |||
| 74 | bool LoadFakeData(int64_t n); | ||
| 75 | bool DumpFakeData(int64_t n); | ||
| 76 | bool LoadCkpt(const std::vector<std::string>& model_config_path, | ||
| 77 | const std::vector<std::string>& emb_file_path); | ||
| 78 | |||
| 79 | 4 | int shard_count() const { return num_shards_; } /* read-only accessor for num_shards_ */ | |
| 80 | |||
| 81 | private: | ||
| 82 | struct DistPrefetchShardState { /* one element of DistPrefetchState::shard_states; field meanings below are inferred from names only — confirm against the implementation file */ | ||
| 83 | int shard_id = -1; /* -1 by default, presumably meaning not yet assigned — confirm */ | ||
| 84 | int client_index = -1; /* presumably an index into clients_; -1 by default — confirm */ | ||
| 85 | std::vector<size_t> original_indices; /* presumably positions of this shard's keys in the caller's key array — confirm */ | ||
| 86 | std::vector<uint64_t> child_prefetch_ids; /* presumably ids returned by the per-server clients — confirm */ | ||
| 87 | std::vector<int> chunk_sizes; /* presumably key counts per sub-request, related to max_keys_per_request_ — confirm */ | ||
| 88 | }; | ||
| 89 | |||
| 90 | struct DistPrefetchState { /* value type stored (via shared_ptr) in prefetch_states_ */ | ||
| 91 | size_t total_keys = 0; | ||
| 92 | std::vector<DistPrefetchShardState> shard_states; | ||
| 93 | }; | ||
| 94 | |||
| 95 | int GetShardId(uint64_t key) const; /* maps one key to a shard id; presumably driven by hash_method_ and num_shards_ — confirm */ | ||
| 96 | |||
| 97 | void InitializeClients(); /* presumably fills clients_ and shard_to_client_index_ from server_configs_ — confirm */ | ||
| 98 | |||
| 99 | void | ||
| 100 | PartitionKeys(const base::ConstArray<uint64_t>& keys, | ||
| 101 | std::vector<std::vector<uint64_t>>& partitioned_keys) const; /* partitioned_keys is an output reference; presumably one key list per shard — confirm */ | ||
| 102 | |||
| 103 | void MergeResults( | ||
| 104 | const base::ConstArray<uint64_t>& keys, | ||
| 105 | const std::vector<std::vector<std::vector<float>>>& partitioned_results, | ||
| 106 | std::vector<std::vector<float>>* values) const; /* combines partitioned_results into the nested vector values; ordering rules are not visible here */ | ||
| 107 | |||
| 108 | void MergeResultsToArray( | ||
| 109 | const base::ConstArray<uint64_t>& keys, | ||
| 110 | const std::vector<std::vector<std::vector<float>>>& partitioned_results, | ||
| 111 | float* values) const; /* same inputs as MergeResults, but the destination is a raw float pointer */ | ||
| 112 | |||
| 113 | private: /* second private section: data members only */ | ||
| 114 | // Config | ||
| 115 | int num_shards_; /* returned by shard_count(). NOTE(review): no in-class initializer here or on max_keys_per_request_ — confirm the constructor always sets both */ | ||
| 116 | int max_keys_per_request_; | ||
| 117 | std::string hash_method_; | ||
| 118 | |||
| 119 | // Per-server entries | ||
| 120 | struct ServerConfig { | ||
| 121 | std::string host; | ||
| 122 | int port; /* NOTE(review): port and shard have no default value — confirm every entry is fully populated when parsed */ | ||
| 123 | int shard; | ||
| 124 | }; | ||
| 125 | std::vector<ServerConfig> server_configs_; | ||
| 126 | |||
| 127 | // bRPC clients (one per server entry) | ||
| 128 | std::vector<std::unique_ptr<BRPCParameterClient>> clients_; /* owning pointers; the unique_ptr elements (and the std::mutex member below) make this class non-copyable */ | ||
| 129 | |||
| 130 | // Logical shard id -> index in clients_ | ||
| 131 | std::unordered_map<int, int> shard_to_client_index_; | ||
| 132 | |||
| 133 | // Partition buffers (reused) | ||
| 134 | mutable std::vector<std::vector<uint64_t>> partitioned_key_buffer_; | ||
| 135 | mutable std::vector<std::vector<size_t>> key_index_mapping_; /* both buffers are mutable, so const member functions can write to them. NOTE(review): shared scratch state — check whether concurrent calls on one client are expected */ | ||
| 136 | |||
| 137 | std::mutex prefetch_mu_; /* presumably guards prefetch_states_ and next_prefetch_id_ — confirm in the implementation file */ | ||
| 138 | std::unordered_map<uint64_t, std::shared_ptr<DistPrefetchState>> | ||
| 139 | prefetch_states_; /* keyed by uint64_t, presumably the id returned by PrefetchParameter — confirm */ | ||
| 140 | uint64_t next_prefetch_id_ = 1; /* starts at 1; presumably 0 is kept as an invalid id — confirm */ | ||
| 141 | }; | ||
| 142 | |||
| 143 | } // namespace recstore | ||
| 144 |