ps/grpc/dist_grpc_ps_client.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "dist_grpc_ps_client.h" | ||
| 2 | |||
| 3 | #include <algorithm> | ||
| 4 | #include <cstring> | ||
| 5 | #include <future> | ||
| 6 | #include <thread> | ||
| 7 | #include <fstream> | ||
| 8 | |||
| 9 | #include "base/factory.h" | ||
| 10 | #include "base/log.h" | ||
| 11 | #include "base/timer.h" | ||
| 12 | |||
| 13 | using grpc::Channel; | ||
| 14 | using grpc::ClientAsyncResponseReader; | ||
| 15 | using grpc::ClientContext; | ||
| 16 | using grpc::Status; | ||
| 17 | using recstoreps::CommandRequest; | ||
| 18 | using recstoreps::CommandResponse; | ||
| 19 | using recstoreps::GetParameterRequest; | ||
| 20 | using recstoreps::GetParameterResponse; | ||
| 21 | using recstoreps::InitEmbeddingTableRequest; | ||
| 22 | using recstoreps::InitEmbeddingTableResponse; | ||
| 23 | using recstoreps::PSCommand; | ||
| 24 | using recstoreps::PutParameterRequest; | ||
| 25 | using recstoreps::PutParameterResponse; | ||
| 26 | using recstoreps::UpdateParameterRequest; | ||
| 27 | using recstoreps::UpdateParameterResponse; | ||
| 28 | |||
| 29 | namespace recstore { | ||
| 30 | |||
| 31 | FACTORY_REGISTER( | ||
| 32 | BasePSClient, distributed_grpc, DistributedGRPCParameterClient, json); | ||
| 33 | |||
| 34 | 18 | DistributedGRPCParameterClient::DistributedGRPCParameterClient(json config) | |
| 35 |
1/2✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
|
18 | : BasePSClient(config) { |
| 36 | 18 | json client_config; | |
| 37 | |||
| 38 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
|
18 | if (config.contains("distributed_client")) { |
| 39 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
36 | LOG(INFO) << "Detected recstore config format, extracting " |
| 40 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | "distributed_client section"; |
| 41 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
18 | client_config = config["distributed_client"]; |
| 42 | } else { | ||
| 43 | ✗ | LOG(FATAL) | |
| 44 | << "Invalid config format. Expected either recstore config with " | ||
| 45 | "'distributed_client' section " | ||
| 46 | ✗ | << "or direct client config with 'servers' and 'num_shards' fields"; | |
| 47 | } | ||
| 48 | |||
| 49 | // 验证必要字段 | ||
| 50 |
3/6✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
|
36 | if (!client_config.contains("servers") || |
| 51 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
|
18 | !client_config["servers"].is_array()) { |
| 52 | ✗ | LOG(FATAL) | |
| 53 | ✗ | << "Missing or invalid 'servers' field in distributed client config"; | |
| 54 | } | ||
| 55 | |||
| 56 |
3/6✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
|
36 | if (!client_config.contains("num_shards") || |
| 57 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
|
18 | !client_config["num_shards"].is_number_integer()) { |
| 58 | ✗ | LOG(FATAL) | |
| 59 | ✗ | << "Missing or invalid 'num_shards' field in distributed client config"; | |
| 60 | } | ||
| 61 | |||
| 62 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
18 | num_shards_ = client_config["num_shards"].get<int>(); |
| 63 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | max_keys_per_request_ = client_config.value("max_keys_per_request", 500); |
| 64 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | hash_method_ = client_config.value("hash_method", "city_hash"); |
| 65 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
|
18 | if (max_keys_per_request_ <= 0) { |
| 66 | ✗ | LOG(FATAL) << "Invalid max_keys_per_request: " << max_keys_per_request_ | |
| 67 | ✗ | << ", must be > 0"; | |
| 68 | } | ||
| 69 | |||
| 70 | // 解析服务器配置 | ||
| 71 |
2/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
18 | auto servers = client_config["servers"]; |
| 72 |
1/2✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
|
18 | server_configs_.reserve(servers.size()); |
| 73 | |||
| 74 |
2/2✓ Branch 1 taken 36 times.
✓ Branch 2 taken 18 times.
|
54 | for (size_t i = 0; i < servers.size(); ++i) { |
| 75 |
1/2✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
|
36 | const auto& server = servers[i]; |
| 76 |
5/10✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 36 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 36 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 36 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 36 times.
|
72 | if (!server.contains("host") || !server.contains("port") || |
| 77 |
2/4✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 36 times.
|
36 | !server.contains("shard")) { |
| 78 | ✗ | LOG(FATAL) << "Server config " << i | |
| 79 | ✗ | << " missing required fields (host, port, shard)"; | |
| 80 | } | ||
| 81 | |||
| 82 | 36 | ServerConfig cfg; | |
| 83 |
2/4✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
|
36 | cfg.host = server["host"].get<std::string>(); |
| 84 |
2/4✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
|
36 | cfg.port = server["port"].get<int>(); |
| 85 |
2/4✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
|
36 | cfg.shard = server["shard"].get<int>(); |
| 86 | |||
| 87 |
1/2✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
|
36 | server_configs_.push_back(cfg); |
| 88 |
1/2✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
|
36 | shard_to_client_index_[cfg.shard] = i; |
| 89 | 36 | } | |
| 90 | |||
| 91 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
|
18 | if (server_configs_.size() != static_cast<size_t>(num_shards_)) { |
| 92 | ✗ | LOG(WARNING) << "Number of servers (" << server_configs_.size() | |
| 93 | ✗ | << ") doesn't match num_shards (" << num_shards_ << ")"; | |
| 94 | } | ||
| 95 | |||
| 96 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | partitioned_key_buffer_.resize(num_shards_); |
| 97 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | key_index_mapping_.resize(num_shards_); |
| 98 | |||
| 99 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | InitializeClients(); |
| 100 | |||
| 101 |
3/6✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
|
36 | LOG(INFO) << "Initialized DistributedGRPCParameterClient with " << num_shards_ |
| 102 |
3/6✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
|
18 | << " shards, hash method: " << hash_method_; |
| 103 | 18 | } | |
| 104 | |||
| 105 | 26 | DistributedGRPCParameterClient::~DistributedGRPCParameterClient() {} | |
| 106 | |||
| 107 | 18 | void DistributedGRPCParameterClient::InitializeClients() { | |
| 108 | 18 | clients_.clear(); | |
| 109 | 18 | clients_.reserve(server_configs_.size()); | |
| 110 | |||
| 111 |
2/2✓ Branch 5 taken 36 times.
✓ Branch 6 taken 18 times.
|
54 | for (const auto& server_config : server_configs_) { |
| 112 | // 为每个服务器创建独立的配置 | ||
| 113 | 36 | json client_config = {{"host", server_config.host}, | |
| 114 | 36 | {"port", server_config.port}, | |
| 115 |
8/16✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 36 times.
✗ Branch 11 not taken.
✓ Branch 14 taken 36 times.
✗ Branch 15 not taken.
✓ Branch 17 taken 36 times.
✗ Branch 18 not taken.
✓ Branch 21 taken 36 times.
✗ Branch 22 not taken.
✓ Branch 24 taken 36 times.
✗ Branch 25 not taken.
|
468 | {"shard", server_config.shard}}; |
| 116 | |||
| 117 |
3/6✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
|
36 | auto* raw_client = new GRPCParameterClient(client_config); |
| 118 | 36 | auto client = std::unique_ptr<GRPCParameterClient>(raw_client); | |
| 119 |
1/2✓ Branch 2 taken 36 times.
✗ Branch 3 not taken.
|
36 | clients_.push_back(std::move(client)); |
| 120 | |||
| 121 |
3/6✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
|
72 | LOG(INFO) << "Created gRPC client for shard " << server_config.shard |
| 122 |
5/10✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 36 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 36 times.
✗ Branch 14 not taken.
|
36 | << " at " << server_config.host << ":" << server_config.port; |
| 123 | 36 | } | |
| 124 | 18 | } | |
| 125 | |||
| 126 | 2040 | int DistributedGRPCParameterClient::GetShardId(uint64_t key) const { | |
| 127 |
1/2✓ Branch 1 taken 2040 times.
✗ Branch 2 not taken.
|
2040 | if (hash_method_ == "city_hash") { |
| 128 | 2040 | return GetHash(key) % num_shards_; | |
| 129 | ✗ | } else if (hash_method_ == "simple_mod") { | |
| 130 | ✗ | return key % num_shards_; | |
| 131 | } else { | ||
| 132 | ✗ | LOG(ERROR) << "Unknown hash method: " << hash_method_ | |
| 133 | ✗ | << ", using city_hash"; | |
| 134 | ✗ | return GetHash(key) % num_shards_; | |
| 135 | } | ||
| 136 | } | ||
| 137 | |||
| 138 | 106 | void DistributedGRPCParameterClient::PartitionKeys( | |
| 139 | const base::ConstArray<uint64_t>& keys, | ||
| 140 | std::vector<std::vector<uint64_t>>& partitioned_keys) const { | ||
| 141 |
2/2✓ Branch 5 taken 212 times.
✓ Branch 6 taken 106 times.
|
318 | for (auto& partition : partitioned_key_buffer_) { |
| 142 | 212 | partition.clear(); | |
| 143 | } | ||
| 144 |
2/2✓ Branch 5 taken 212 times.
✓ Branch 6 taken 106 times.
|
318 | for (auto& mapping : key_index_mapping_) { |
| 145 | 212 | mapping.clear(); | |
| 146 | } | ||
| 147 | |||
| 148 |
2/2✓ Branch 1 taken 1616 times.
✓ Branch 2 taken 106 times.
|
1722 | for (size_t i = 0; i < keys.Size(); ++i) { |
| 149 | 1616 | uint64_t key = keys[i]; | |
| 150 |
1/2✓ Branch 1 taken 1616 times.
✗ Branch 2 not taken.
|
1616 | int shard_id = GetShardId(key); |
| 151 | |||
| 152 |
1/2✓ Branch 2 taken 1616 times.
✗ Branch 3 not taken.
|
1616 | partitioned_key_buffer_[shard_id].push_back(key); |
| 153 |
1/2✓ Branch 2 taken 1616 times.
✗ Branch 3 not taken.
|
1616 | key_index_mapping_[shard_id].push_back(i); |
| 154 | } | ||
| 155 | |||
| 156 | 106 | partitioned_keys = partitioned_key_buffer_; | |
| 157 | 106 | } | |
| 158 | |||
| 159 | 46 | bool DistributedGRPCParameterClient::GetParameter( | |
| 160 | const base::ConstArray<uint64_t>& keys, | ||
| 161 | std::vector<std::vector<float>>* values) { | ||
| 162 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 46 times.
|
46 | if (keys.Size() == 0) { |
| 163 | ✗ | values->clear(); | |
| 164 | ✗ | return true; | |
| 165 | } | ||
| 166 | |||
| 167 |
2/4✓ Branch 2 taken 46 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 46 times.
✗ Branch 6 not taken.
|
92 | xmh::Timer timer("DistributedGRPCParameterClient::GetParameter"); |
| 168 | |||
| 169 | 46 | std::vector<std::vector<uint64_t>> partitioned_keys; | |
| 170 |
1/2✓ Branch 1 taken 46 times.
✗ Branch 2 not taken.
|
46 | PartitionKeys(keys, partitioned_keys); |
| 171 | |||
| 172 | 46 | std::vector<std::future<int>> futures; | |
| 173 |
1/2✓ Branch 2 taken 46 times.
✗ Branch 3 not taken.
|
46 | std::vector<std::vector<std::vector<float>>> partitioned_results(num_shards_); |
| 174 | |||
| 175 |
2/2✓ Branch 0 taken 92 times.
✓ Branch 1 taken 46 times.
|
138 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 176 |
2/2✓ Branch 2 taken 8 times.
✓ Branch 3 taken 84 times.
|
92 | if (partitioned_keys[shard_id].empty()) { |
| 177 | 8 | continue; | |
| 178 | } | ||
| 179 | |||
| 180 |
1/2✓ Branch 1 taken 84 times.
✗ Branch 2 not taken.
|
84 | auto it = shard_to_client_index_.find(shard_id); |
| 181 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 84 times.
|
84 | if (it == shard_to_client_index_.end()) { |
| 182 | ✗ | LOG(ERROR) << "No client found for shard " << shard_id; | |
| 183 | ✗ | return false; | |
| 184 | } | ||
| 185 | |||
| 186 | 84 | int client_index = it->second; | |
| 187 | 84 | auto* client = clients_[client_index].get(); | |
| 188 | |||
| 189 | // 异步请求 | ||
| 190 |
2/4✓ Branch 1 taken 84 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
|
84 | futures.push_back(std::async( |
| 191 | 168 | std::launch::async, [=, &partitioned_keys, &partitioned_results]() { | |
| 192 | 84 | const auto& shard_keys_vec = partitioned_keys[shard_id]; | |
| 193 | 84 | auto& shard_result_vec = partitioned_results[shard_id]; | |
| 194 | 84 | shard_result_vec.clear(); | |
| 195 | 84 | shard_result_vec.reserve(shard_keys_vec.size()); | |
| 196 | |||
| 197 |
2/2✓ Branch 1 taken 88 times.
✓ Branch 2 taken 84 times.
|
172 | for (size_t start = 0; start < shard_keys_vec.size(); |
| 198 | 88 | start += static_cast<size_t>(max_keys_per_request_)) { | |
| 199 | size_t end = | ||
| 200 | 176 | std::min(start + static_cast<size_t>(max_keys_per_request_), | |
| 201 | 88 | shard_keys_vec.size()); | |
| 202 | base::ConstArray<uint64_t> shard_chunk( | ||
| 203 | 88 | shard_keys_vec.data() + start, static_cast<int>(end - start)); | |
| 204 | 88 | std::vector<std::vector<float>> chunk_result; | |
| 205 |
2/4✓ Branch 1 taken 88 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 88 times.
|
88 | if (!client->GetParameter(shard_chunk, &chunk_result)) { |
| 206 | ✗ | return 0; | |
| 207 | } | ||
| 208 |
1/2✓ Branch 5 taken 88 times.
✗ Branch 6 not taken.
|
88 | shard_result_vec.insert(shard_result_vec.end(), |
| 209 | chunk_result.begin(), | ||
| 210 | chunk_result.end()); | ||
| 211 |
1/2✓ Branch 1 taken 88 times.
✗ Branch 2 not taken.
|
88 | } |
| 212 | 84 | return 1; | |
| 213 | })); | ||
| 214 | } | ||
| 215 | |||
| 216 | // 3. 等待所有请求完成 | ||
| 217 |
2/2✓ Branch 5 taken 84 times.
✓ Branch 6 taken 46 times.
|
130 | for (auto& future : futures) { |
| 218 |
2/4✓ Branch 1 taken 84 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
|
84 | if (!future.get()) { |
| 219 | ✗ | LOG(ERROR) << "Failed to get parameters from one of the shards"; | |
| 220 | ✗ | return false; | |
| 221 | } | ||
| 222 | } | ||
| 223 | |||
| 224 | // 4. 合并结果 | ||
| 225 |
1/2✓ Branch 1 taken 46 times.
✗ Branch 2 not taken.
|
46 | MergeResults(keys, partitioned_results, values); |
| 226 | |||
| 227 | 46 | return true; | |
| 228 | 46 | } | |
| 229 | |||
| 230 | 46 | void DistributedGRPCParameterClient::MergeResults( | |
| 231 | const base::ConstArray<uint64_t>& keys, | ||
| 232 | const std::vector<std::vector<std::vector<float>>>& partitioned_results, | ||
| 233 | std::vector<std::vector<float>>* values) const { | ||
| 234 | 46 | values->clear(); | |
| 235 |
1/2✓ Branch 2 taken 46 times.
✗ Branch 3 not taken.
|
46 | values->resize(keys.Size()); |
| 236 | |||
| 237 | // 重建key -> index映射 | ||
| 238 | 46 | std::unordered_map<uint64_t, size_t> key_to_result_index; | |
| 239 |
2/2✓ Branch 0 taken 92 times.
✓ Branch 1 taken 46 times.
|
138 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 240 |
2/2✓ Branch 2 taken 368 times.
✓ Branch 3 taken 92 times.
|
460 | for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) { |
| 241 | 368 | size_t original_index = key_index_mapping_[shard_id][i]; | |
| 242 |
1/2✓ Branch 2 taken 368 times.
✗ Branch 3 not taken.
|
368 | if (i < partitioned_results[shard_id].size()) { |
| 243 |
1/2✓ Branch 4 taken 368 times.
✗ Branch 5 not taken.
|
368 | (*values)[original_index] = partitioned_results[shard_id][i]; |
| 244 | } | ||
| 245 | } | ||
| 246 | } | ||
| 247 | 46 | } | |
| 248 | |||
| 249 | ✗ | void DistributedGRPCParameterClient::MergeResultsToArray( | |
| 250 | const base::ConstArray<uint64_t>& keys, | ||
| 251 | const std::vector<std::vector<std::vector<float>>>& partitioned_results, | ||
| 252 | float* values) const { | ||
| 253 | ✗ | int emb_dim = 0; | |
| 254 | ✗ | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { | |
| 255 | ✗ | if (!partitioned_results[shard_id].empty() && | |
| 256 | ✗ | !partitioned_results[shard_id][0].empty()) { | |
| 257 | ✗ | emb_dim = partitioned_results[shard_id][0].size(); | |
| 258 | ✗ | break; | |
| 259 | } | ||
| 260 | } | ||
| 261 | |||
| 262 | ✗ | if (emb_dim == 0) { | |
| 263 | ✗ | LOG(WARNING) << "No valid embeddings found"; | |
| 264 | ✗ | return; | |
| 265 | } | ||
| 266 | |||
| 267 | // 合并结果到连续内存 | ||
| 268 | ✗ | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { | |
| 269 | ✗ | for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) { | |
| 270 | ✗ | size_t original_index = key_index_mapping_[shard_id][i]; | |
| 271 | ✗ | if (i < partitioned_results[shard_id].size()) { | |
| 272 | ✗ | const auto& embedding = partitioned_results[shard_id][i]; | |
| 273 | ✗ | std::copy(embedding.begin(), | |
| 274 | embedding.end(), | ||
| 275 | ✗ | values + original_index * emb_dim); | |
| 276 | } | ||
| 277 | } | ||
| 278 | } | ||
| 279 | } | ||
| 280 | |||
| 281 | // 实现BasePSClient接口 | ||
| 282 | 32 | int DistributedGRPCParameterClient::GetParameter( | |
| 283 | const base::ConstArray<uint64_t>& keys, float* values) { | ||
| 284 | 32 | std::vector<std::vector<float>> result_vectors; | |
| 285 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | bool success = GetParameter(keys, &result_vectors); |
| 286 | |||
| 287 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
|
32 | if (!success) { |
| 288 | ✗ | return -1; | |
| 289 | } | ||
| 290 | |||
| 291 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 32 times.
|
32 | if (keys.Size() == 0) { |
| 292 | ✗ | return 0; | |
| 293 | } | ||
| 294 | 32 | int emb_dim = 0; | |
| 295 |
1/2✓ Branch 5 taken 32 times.
✗ Branch 6 not taken.
|
32 | for (const auto& row : result_vectors) { |
| 296 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | if (!row.empty()) { |
| 297 | 32 | emb_dim = static_cast<int>(row.size()); | |
| 298 | 32 | break; | |
| 299 | } | ||
| 300 | } | ||
| 301 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
|
32 | if (emb_dim == 0) { |
| 302 | ✗ | LOG(WARNING) << "No valid embeddings found"; | |
| 303 | ✗ | return 0; | |
| 304 | } | ||
| 305 | |||
| 306 |
2/2✓ Branch 1 taken 92 times.
✓ Branch 2 taken 32 times.
|
124 | for (size_t i = 0; i < result_vectors.size(); ++i) { |
| 307 | 92 | const auto& row = result_vectors[i]; | |
| 308 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 92 times.
|
92 | if (row.empty()) { |
| 309 | ✗ | continue; | |
| 310 | } | ||
| 311 |
1/2✓ Branch 3 taken 92 times.
✗ Branch 4 not taken.
|
92 | std::copy(row.begin(), row.end(), values + i * emb_dim); |
| 312 | } | ||
| 313 | 32 | return 0; | |
| 314 | 32 | } | |
| 315 | |||
| 316 | ✗ | int DistributedGRPCParameterClient::AsyncGetParameter( | |
| 317 | const base::ConstArray<uint64_t>& keys, float* values) { | ||
| 318 | ✗ | return GetParameter(keys, values); | |
| 319 | } | ||
| 320 | |||
| 321 | 58 | int DistributedGRPCParameterClient::PutParameter( | |
| 322 | const base::ConstArray<uint64_t>& keys, | ||
| 323 | const std::vector<std::vector<float>>& values) { | ||
| 324 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 58 times.
|
58 | if (keys.Size() != values.size()) { |
| 325 | ✗ | LOG(ERROR) << "Keys and values size mismatch: " << keys.Size() << " vs " | |
| 326 | ✗ | << values.size(); | |
| 327 | ✗ | return -1; | |
| 328 | } | ||
| 329 | |||
| 330 | 58 | std::vector<std::vector<uint64_t>> partitioned_keys; | |
| 331 |
1/2✓ Branch 1 taken 58 times.
✗ Branch 2 not taken.
|
58 | PartitionKeys(keys, partitioned_keys); |
| 332 | |||
| 333 |
1/2✓ Branch 2 taken 58 times.
✗ Branch 3 not taken.
|
58 | std::vector<std::vector<std::vector<float>>> partitioned_values(num_shards_); |
| 334 |
2/2✓ Branch 0 taken 116 times.
✓ Branch 1 taken 58 times.
|
174 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 335 |
2/2✓ Branch 2 taken 1244 times.
✓ Branch 3 taken 116 times.
|
1360 | for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) { |
| 336 | 1244 | size_t original_index = key_index_mapping_[shard_id][i]; | |
| 337 |
1/2✓ Branch 3 taken 1244 times.
✗ Branch 4 not taken.
|
1244 | partitioned_values[shard_id].push_back(values[original_index]); |
| 338 | } | ||
| 339 | } | ||
| 340 | |||
| 341 | // 并发put到各个分片 | ||
| 342 | 58 | std::vector<std::future<int>> futures; | |
| 343 | |||
| 344 |
2/2✓ Branch 0 taken 116 times.
✓ Branch 1 taken 58 times.
|
174 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 345 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 116 times.
|
116 | if (partitioned_keys[shard_id].empty()) { |
| 346 | ✗ | continue; | |
| 347 | } | ||
| 348 | |||
| 349 |
1/2✓ Branch 1 taken 116 times.
✗ Branch 2 not taken.
|
116 | auto it = shard_to_client_index_.find(shard_id); |
| 350 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 116 times.
|
116 | if (it == shard_to_client_index_.end()) { |
| 351 | ✗ | LOG(ERROR) << "No client found for shard " << shard_id; | |
| 352 | ✗ | return -1; | |
| 353 | } | ||
| 354 | |||
| 355 | 116 | int client_index = it->second; | |
| 356 | 116 | auto* client = clients_[client_index].get(); | |
| 357 | |||
| 358 |
2/4✓ Branch 1 taken 116 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 116 times.
✗ Branch 5 not taken.
|
116 | futures.push_back(std::async( |
| 359 | 232 | std::launch::async, [=, &partitioned_keys, &partitioned_values]() { | |
| 360 | 116 | const auto& shard_keys_vec = partitioned_keys[shard_id]; | |
| 361 | 116 | const auto& shard_vals_vec = partitioned_values[shard_id]; | |
| 362 |
2/2✓ Branch 1 taken 148 times.
✓ Branch 2 taken 116 times.
|
264 | for (size_t start = 0; start < shard_keys_vec.size(); |
| 363 | 148 | start += static_cast<size_t>(max_keys_per_request_)) { | |
| 364 | size_t end = | ||
| 365 | 296 | std::min(start + static_cast<size_t>(max_keys_per_request_), | |
| 366 | 148 | shard_keys_vec.size()); | |
| 367 | base::ConstArray<uint64_t> shard_chunk( | ||
| 368 | 148 | shard_keys_vec.data() + start, static_cast<int>(end - start)); | |
| 369 | std::vector<std::vector<float>> value_chunk( | ||
| 370 |
1/2✓ Branch 6 taken 148 times.
✗ Branch 7 not taken.
|
148 | shard_vals_vec.begin() + start, shard_vals_vec.begin() + end); |
| 371 |
2/4✓ Branch 1 taken 148 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 148 times.
|
148 | if (client->PutParameter(shard_chunk, value_chunk) != 1) { |
| 372 | ✗ | return 0; | |
| 373 | } | ||
| 374 |
1/2✓ Branch 1 taken 148 times.
✗ Branch 2 not taken.
|
148 | } |
| 375 | 116 | return 1; | |
| 376 | })); | ||
| 377 | } | ||
| 378 | |||
| 379 | // 等待所有请求完成 | ||
| 380 |
2/2✓ Branch 5 taken 116 times.
✓ Branch 6 taken 58 times.
|
174 | for (auto& future : futures) { |
| 381 |
2/4✓ Branch 1 taken 116 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 116 times.
|
116 | if (future.get() != 1) { |
| 382 | ✗ | LOG(ERROR) << "Failed to put parameters to one of the shards"; | |
| 383 | ✗ | return -1; | |
| 384 | } | ||
| 385 | } | ||
| 386 | |||
| 387 | 58 | return 0; | |
| 388 | 58 | } | |
| 389 | |||
| 390 | ✗ | void DistributedGRPCParameterClient::Command(PSCommand command) { | |
| 391 | ✗ | std::vector<std::future<void>> futures; | |
| 392 | |||
| 393 | ✗ | for (auto& client : clients_) { | |
| 394 | ✗ | futures.push_back(std::async(std::launch::async, [&client, command]() { | |
| 395 | ✗ | client->Command(command); | |
| 396 | ✗ | })); | |
| 397 | } | ||
| 398 | |||
| 399 | ✗ | for (auto& future : futures) { | |
| 400 | ✗ | future.wait(); | |
| 401 | } | ||
| 402 | ✗ | } | |
| 403 | |||
| 404 | 14 | bool DistributedGRPCParameterClient::ClearPS() { | |
| 405 | 14 | std::vector<std::future<bool>> futures; | |
| 406 | |||
| 407 |
2/2✓ Branch 4 taken 28 times.
✓ Branch 5 taken 14 times.
|
42 | for (auto& client : clients_) { |
| 408 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | futures.push_back(std::async(std::launch::async, [&client]() { |
| 409 | 28 | return client->ClearPS(); | |
| 410 | })); | ||
| 411 | } | ||
| 412 | |||
| 413 | 14 | bool all_success = true; | |
| 414 |
2/2✓ Branch 5 taken 28 times.
✓ Branch 6 taken 14 times.
|
42 | for (auto& future : futures) { |
| 415 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 28 times.
|
28 | if (!future.get()) { |
| 416 | ✗ | all_success = false; | |
| 417 | } | ||
| 418 | } | ||
| 419 | |||
| 420 | 14 | return all_success; | |
| 421 | 14 | } | |
| 422 | |||
| 423 | 2 | bool DistributedGRPCParameterClient::LoadFakeData(int64_t n) { | |
| 424 | 2 | std::vector<std::future<bool>> futures; | |
| 425 |
2/2✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
|
6 | for (auto& client : clients_) { |
| 426 | 4 | GRPCParameterClient* raw = client.get(); | |
| 427 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
|
4 | futures.push_back(std::async(std::launch::async, [raw, n]() { |
| 428 | 4 | return raw->LoadFakeData(n); | |
| 429 | })); | ||
| 430 | } | ||
| 431 | 2 | bool all_success = true; | |
| 432 |
2/2✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
|
6 | for (auto& future : futures) { |
| 433 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
|
4 | if (!future.get()) { |
| 434 | ✗ | all_success = false; | |
| 435 | } | ||
| 436 | } | ||
| 437 | 2 | return all_success; | |
| 438 | 2 | } | |
| 439 | |||
| 440 | 2 | bool DistributedGRPCParameterClient::DumpFakeData(int64_t n) { | |
| 441 | 2 | std::vector<std::future<bool>> futures; | |
| 442 |
2/2✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
|
6 | for (auto& client : clients_) { |
| 443 | 4 | GRPCParameterClient* raw = client.get(); | |
| 444 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
|
4 | futures.push_back(std::async(std::launch::async, [raw, n]() { |
| 445 | 4 | return raw->DumpFakeData(n); | |
| 446 | })); | ||
| 447 | } | ||
| 448 | 2 | bool all_success = true; | |
| 449 |
2/2✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
|
6 | for (auto& future : futures) { |
| 450 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
|
4 | if (!future.get()) { |
| 451 | ✗ | all_success = false; | |
| 452 | } | ||
| 453 | } | ||
| 454 | 2 | return all_success; | |
| 455 | 2 | } | |
| 456 | |||
| 457 | ✗ | bool DistributedGRPCParameterClient::LoadCkpt( | |
| 458 | const std::vector<std::string>& model_config_path, | ||
| 459 | const std::vector<std::string>& emb_file_path) { | ||
| 460 | ✗ | std::vector<std::future<bool>> futures; | |
| 461 | |||
| 462 | ✗ | for (auto& client : clients_) { | |
| 463 | ✗ | futures.push_back(std::async( | |
| 464 | ✗ | std::launch::async, [&client, &model_config_path, &emb_file_path]() { | |
| 465 | ✗ | return client->LoadCkpt(model_config_path, emb_file_path); | |
| 466 | })); | ||
| 467 | } | ||
| 468 | |||
| 469 | ✗ | bool all_success = true; | |
| 470 | ✗ | for (auto& future : futures) { | |
| 471 | ✗ | if (!future.get()) { | |
| 472 | ✗ | all_success = false; | |
| 473 | } | ||
| 474 | } | ||
| 475 | |||
| 476 | ✗ | return all_success; | |
| 477 | ✗ | } | |
| 478 | |||
| 479 | 2 | int DistributedGRPCParameterClient::UpdateParameter( | |
| 480 | const std::string& table_name, | ||
| 481 | const base::ConstArray<uint64_t>& keys, | ||
| 482 | const std::vector<std::vector<float>>* grads) { | ||
| 483 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
|
2 | if (grads == nullptr) { |
| 484 | ✗ | LOG(ERROR) << "UpdateParameter grads pointer is null"; | |
| 485 | ✗ | return -1; | |
| 486 | } | ||
| 487 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
|
2 | if (keys.Size() != grads->size()) { |
| 488 | ✗ | LOG(ERROR) << "UpdateParameter keys/grads size mismatch: " << keys.Size() | |
| 489 | ✗ | << " vs " << grads->size(); | |
| 490 | ✗ | return -1; | |
| 491 | } | ||
| 492 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
|
2 | if (keys.Size() == 0) { |
| 493 | ✗ | return 0; | |
| 494 | } | ||
| 495 | |||
| 496 | 2 | std::vector<std::vector<uint64_t>> partitioned_keys; | |
| 497 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | PartitionKeys(keys, partitioned_keys); |
| 498 | |||
| 499 |
1/2✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
2 | std::vector<std::vector<std::vector<float>>> partitioned_grads(num_shards_); |
| 500 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
|
6 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 501 |
2/2✓ Branch 2 taken 4 times.
✓ Branch 3 taken 4 times.
|
8 | for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) { |
| 502 | 4 | size_t original_index = key_index_mapping_[shard_id][i]; | |
| 503 |
1/2✓ Branch 3 taken 4 times.
✗ Branch 4 not taken.
|
4 | partitioned_grads[shard_id].push_back((*grads)[original_index]); |
| 504 | } | ||
| 505 | } | ||
| 506 | |||
| 507 | 2 | std::vector<std::future<int>> futures; | |
| 508 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
|
6 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 509 |
2/2✓ Branch 2 taken 2 times.
✓ Branch 3 taken 2 times.
|
4 | if (partitioned_keys[shard_id].empty()) { |
| 510 | 2 | continue; | |
| 511 | } | ||
| 512 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | auto it = shard_to_client_index_.find(shard_id); |
| 513 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
|
2 | if (it == shard_to_client_index_.end()) { |
| 514 | ✗ | LOG(ERROR) << "No client found for shard " << shard_id; | |
| 515 | ✗ | return -1; | |
| 516 | } | ||
| 517 | 2 | int client_index = it->second; | |
| 518 | 2 | auto* client = clients_[client_index].get(); | |
| 519 | |||
| 520 |
3/6✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
|
2 | futures.push_back(std::async( |
| 521 | 4 | std::launch::async, [=, &partitioned_keys, &partitioned_grads]() { | |
| 522 | 2 | const auto& shard_keys_vec = partitioned_keys[shard_id]; | |
| 523 | 2 | const auto& shard_grads_vec = partitioned_grads[shard_id]; | |
| 524 |
2/2✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
|
4 | for (size_t start = 0; start < shard_keys_vec.size(); |
| 525 | 2 | start += static_cast<size_t>(max_keys_per_request_)) { | |
| 526 | size_t end = | ||
| 527 | 4 | std::min(start + static_cast<size_t>(max_keys_per_request_), | |
| 528 | 2 | shard_keys_vec.size()); | |
| 529 | base::ConstArray<uint64_t> shard_chunk( | ||
| 530 | 2 | shard_keys_vec.data() + start, static_cast<int>(end - start)); | |
| 531 | std::vector<std::vector<float>> grad_chunk( | ||
| 532 |
1/2✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
|
2 | shard_grads_vec.begin() + start, shard_grads_vec.begin() + end); |
| 533 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
|
2 | if (client->UpdateParameter(table_name, shard_chunk, &grad_chunk) != |
| 534 | 0) { | ||
| 535 | ✗ | return -1; | |
| 536 | } | ||
| 537 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | } |
| 538 | 2 | return 0; | |
| 539 | })); | ||
| 540 | } | ||
| 541 | |||
| 542 |
2/2✓ Branch 5 taken 2 times.
✓ Branch 6 taken 2 times.
|
4 | for (auto& future : futures) { |
| 543 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
|
2 | if (future.get() != 0) { |
| 544 | ✗ | LOG(ERROR) << "Failed to update parameters on one of the shards"; | |
| 545 | ✗ | return -1; | |
| 546 | } | ||
| 547 | } | ||
| 548 | |||
| 549 | 2 | return 0; | |
| 550 | 2 | } | |
| 551 | |||
| 552 | 2 | int DistributedGRPCParameterClient::UpdateParameterFlat( | |
| 553 | const std::string& table_name, | ||
| 554 | const base::ConstArray<uint64_t>& keys, | ||
| 555 | const float* grads, | ||
| 556 | int64_t num_rows, | ||
| 557 | int64_t embedding_dim) { | ||
| 558 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
|
2 | if (grads == nullptr) { |
| 559 | ✗ | LOG(ERROR) << "UpdateParameterFlat grads pointer is null"; | |
| 560 | ✗ | return -1; | |
| 561 | } | ||
| 562 |
2/4✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
|
2 | if (num_rows < 0 || embedding_dim <= 0) { |
| 563 | ✗ | LOG(ERROR) << "UpdateParameterFlat invalid shape: rows=" << num_rows | |
| 564 | ✗ | << " dim=" << embedding_dim; | |
| 565 | ✗ | return -1; | |
| 566 | } | ||
| 567 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
|
2 | if (keys.Size() != static_cast<size_t>(num_rows)) { |
| 568 | ✗ | LOG(ERROR) << "UpdateParameterFlat keys/grads size mismatch: " | |
| 569 | ✗ | << keys.Size() << " vs " << num_rows; | |
| 570 | ✗ | return -1; | |
| 571 | } | ||
| 572 | |||
| 573 | 2 | std::vector<std::vector<float>> row_grads; | |
| 574 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | row_grads.reserve(static_cast<size_t>(num_rows)); |
| 575 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
|
6 | for (int64_t i = 0; i < num_rows; ++i) { |
| 576 | 4 | const float* row = grads + i * embedding_dim; | |
| 577 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | row_grads.emplace_back(row, row + embedding_dim); |
| 578 | } | ||
| 579 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | return UpdateParameter(table_name, keys, &row_grads); |
| 580 | 2 | } | |
| 581 | |||
| 582 | 10 | int DistributedGRPCParameterClient::InitEmbeddingTable( | |
| 583 | const std::string& table_name, | ||
| 584 | const recstore::EmbeddingTableConfig& config) { | ||
| 585 | 10 | std::vector<std::future<int>> futures; | |
| 586 |
2/2✓ Branch 4 taken 20 times.
✓ Branch 5 taken 10 times.
|
30 | for (auto& client : clients_) { |
| 587 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | futures.push_back( |
| 588 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
40 | std::async(std::launch::async, [&client, &table_name, &config]() { |
| 589 | 20 | return client->InitEmbeddingTable(table_name, config); | |
| 590 | })); | ||
| 591 | } | ||
| 592 | |||
| 593 |
2/2✓ Branch 5 taken 20 times.
✓ Branch 6 taken 10 times.
|
30 | for (auto& future : futures) { |
| 594 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
|
20 | if (future.get() != 0) { |
| 595 | ✗ | LOG(ERROR) << "InitEmbeddingTable failed on one of the shards"; | |
| 596 | ✗ | return -1; | |
| 597 | } | ||
| 598 | } | ||
| 599 | |||
| 600 | 10 | return 0; | |
| 601 | 10 | } | |
| 602 | |||
| 603 | // Prefetch 接口实现 | ||
| 604 | 38 | uint64_t DistributedGRPCParameterClient::PrefetchParameter( | |
| 605 | const base::ConstArray<uint64_t>& keys) { | ||
| 606 | ✗ | auto cleanup_state = [this](const DistPrefetchState& state) { | |
| 607 | ✗ | for (const auto& shard_state : state.shard_states) { | |
| 608 | ✗ | if (shard_state.client_index < 0 || | |
| 609 | ✗ | shard_state.client_index >= static_cast<int>(clients_.size())) { | |
| 610 | ✗ | continue; | |
| 611 | } | ||
| 612 | ✗ | auto* client = clients_[shard_state.client_index].get(); | |
| 613 | ✗ | for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) { | |
| 614 | ✗ | client->WaitForPrefetch(child_prefetch_id); | |
| 615 | ✗ | std::vector<std::vector<float>> tmp; | |
| 616 | ✗ | client->GetPrefetchResult(child_prefetch_id, &tmp); | |
| 617 | ✗ | } | |
| 618 | } | ||
| 619 | ✗ | }; | |
| 620 | |||
| 621 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 38 times.
|
38 | if (keys.Size() == 0) { |
| 622 | ✗ | std::lock_guard<std::mutex> lk(prefetch_mu_); | |
| 623 | ✗ | uint64_t prefetch_id = next_prefetch_id_++; | |
| 624 | ✗ | auto state = std::make_shared<DistPrefetchState>(); | |
| 625 | ✗ | state->total_keys = 0; | |
| 626 | ✗ | prefetch_states_[prefetch_id] = state; | |
| 627 | ✗ | return prefetch_id; | |
| 628 | ✗ | } | |
| 629 | |||
| 630 |
1/2✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
|
38 | std::vector<std::vector<uint64_t>> shard_keys(num_shards_); |
| 631 |
1/2✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
|
38 | std::vector<std::vector<size_t>> shard_indices(num_shards_); |
| 632 |
2/2✓ Branch 1 taken 424 times.
✓ Branch 2 taken 38 times.
|
462 | for (size_t i = 0; i < keys.Size(); ++i) { |
| 633 |
1/2✓ Branch 2 taken 424 times.
✗ Branch 3 not taken.
|
424 | const int shard_id = GetShardId(keys[i]); |
| 634 |
1/2✓ Branch 3 taken 424 times.
✗ Branch 4 not taken.
|
424 | shard_keys[shard_id].push_back(keys[i]); |
| 635 |
1/2✓ Branch 2 taken 424 times.
✗ Branch 3 not taken.
|
424 | shard_indices[shard_id].push_back(i); |
| 636 | } | ||
| 637 | |||
| 638 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | auto state = std::make_shared<DistPrefetchState>(); |
| 639 | 38 | state->total_keys = keys.Size(); | |
| 640 | |||
| 641 |
2/2✓ Branch 0 taken 76 times.
✓ Branch 1 taken 38 times.
|
114 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 642 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 76 times.
|
76 | if (shard_keys[shard_id].empty()) { |
| 643 | ✗ | continue; | |
| 644 | } | ||
| 645 | |||
| 646 |
1/2✓ Branch 1 taken 76 times.
✗ Branch 2 not taken.
|
76 | auto it = shard_to_client_index_.find(shard_id); |
| 647 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 76 times.
|
76 | if (it == shard_to_client_index_.end()) { |
| 648 | ✗ | LOG(ERROR) << "No client found for shard " << shard_id; | |
| 649 | ✗ | cleanup_state(*state); | |
| 650 | ✗ | return 0; | |
| 651 | } | ||
| 652 | |||
| 653 | 76 | DistPrefetchShardState shard_state; | |
| 654 | 76 | shard_state.shard_id = shard_id; | |
| 655 | 76 | shard_state.client_index = it->second; | |
| 656 | 76 | shard_state.original_indices = std::move(shard_indices[shard_id]); | |
| 657 | |||
| 658 | 76 | const auto& skeys = shard_keys[shard_id]; | |
| 659 |
2/2✓ Branch 1 taken 104 times.
✓ Branch 2 taken 76 times.
|
180 | for (size_t start = 0; start < skeys.size(); |
| 660 | 104 | start += static_cast<size_t>(max_keys_per_request_)) { | |
| 661 | 104 | size_t end = std::min( | |
| 662 | 104 | start + static_cast<size_t>(max_keys_per_request_), skeys.size()); | |
| 663 | base::ConstArray<uint64_t> chunk( | ||
| 664 | 104 | skeys.data() + start, static_cast<int>(end - start)); | |
| 665 | uint64_t child_prefetch_id = | ||
| 666 |
1/2✓ Branch 3 taken 104 times.
✗ Branch 4 not taken.
|
104 | clients_[shard_state.client_index]->PrefetchParameter(chunk); |
| 667 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 104 times.
|
104 | if (child_prefetch_id == 0) { |
| 668 | ✗ | LOG(ERROR) << "PrefetchParameter failed for shard " << shard_id; | |
| 669 | ✗ | cleanup_state(*state); | |
| 670 | ✗ | return 0; | |
| 671 | } | ||
| 672 |
1/2✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
|
104 | shard_state.child_prefetch_ids.push_back(child_prefetch_id); |
| 673 |
1/2✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
|
104 | shard_state.chunk_sizes.push_back(static_cast<int>(end - start)); |
| 674 | } | ||
| 675 |
1/2✓ Branch 3 taken 76 times.
✗ Branch 4 not taken.
|
76 | state->shard_states.push_back(std::move(shard_state)); |
| 676 |
1/2✓ Branch 1 taken 76 times.
✗ Branch 2 not taken.
|
76 | } |
| 677 | |||
| 678 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 679 | 38 | uint64_t prefetch_id = next_prefetch_id_++; | |
| 680 |
1/2✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
|
38 | prefetch_states_[prefetch_id] = std::move(state); |
| 681 | 38 | return prefetch_id; | |
| 682 | 38 | } | |
| 683 | |||
| 684 | 4 | bool DistributedGRPCParameterClient::IsPrefetchDone(uint64_t prefetch_id) { | |
| 685 | 4 | std::shared_ptr<DistPrefetchState> state; | |
| 686 | { | ||
| 687 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 688 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | auto it = prefetch_states_.find(prefetch_id); |
| 689 |
2/2✓ Branch 2 taken 2 times.
✓ Branch 3 taken 2 times.
|
4 | if (it == prefetch_states_.end()) { |
| 690 |
4/8✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
|
2 | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; |
| 691 | 2 | return false; | |
| 692 | } | ||
| 693 | 2 | state = it->second; | |
| 694 |
2/2✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
|
4 | } |
| 695 | |||
| 696 |
2/2✓ Branch 6 taken 4 times.
✓ Branch 7 taken 2 times.
|
6 | for (const auto& shard_state : state->shard_states) { |
| 697 | 4 | auto* client = clients_[shard_state.client_index].get(); | |
| 698 |
2/2✓ Branch 5 taken 4 times.
✓ Branch 6 taken 4 times.
|
8 | for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) { |
| 699 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
|
4 | if (!client->IsPrefetchDone(child_prefetch_id)) { |
| 700 | ✗ | return false; | |
| 701 | } | ||
| 702 | } | ||
| 703 | } | ||
| 704 | 2 | return true; | |
| 705 | 4 | } | |
| 706 | |||
| 707 | 74 | void DistributedGRPCParameterClient::WaitForPrefetch(uint64_t prefetch_id) { | |
| 708 | 74 | std::shared_ptr<DistPrefetchState> state; | |
| 709 | { | ||
| 710 |
1/2✓ Branch 1 taken 74 times.
✗ Branch 2 not taken.
|
74 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 711 |
1/2✓ Branch 1 taken 74 times.
✗ Branch 2 not taken.
|
74 | auto it = prefetch_states_.find(prefetch_id); |
| 712 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 74 times.
|
74 | if (it == prefetch_states_.end()) { |
| 713 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 714 | ✗ | return; | |
| 715 | } | ||
| 716 | 74 | state = it->second; | |
| 717 |
1/2✓ Branch 1 taken 74 times.
✗ Branch 2 not taken.
|
74 | } |
| 718 | |||
| 719 |
2/2✓ Branch 6 taken 148 times.
✓ Branch 7 taken 74 times.
|
222 | for (const auto& shard_state : state->shard_states) { |
| 720 | 148 | auto* client = clients_[shard_state.client_index].get(); | |
| 721 |
2/2✓ Branch 5 taken 204 times.
✓ Branch 6 taken 148 times.
|
352 | for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) { |
| 722 |
1/2✓ Branch 1 taken 204 times.
✗ Branch 2 not taken.
|
204 | client->WaitForPrefetch(child_prefetch_id); |
| 723 | } | ||
| 724 | } | ||
| 725 |
1/2✓ Branch 1 taken 74 times.
✗ Branch 2 not taken.
|
74 | } |
| 726 | |||
| 727 | 42 | bool DistributedGRPCParameterClient::GetPrefetchResult( | |
| 728 | uint64_t prefetch_id, std::vector<std::vector<float>>* values) { | ||
| 729 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 42 times.
|
42 | if (values == nullptr) { |
| 730 | ✗ | LOG(ERROR) << "GetPrefetchResult output pointer is null"; | |
| 731 | ✗ | return false; | |
| 732 | } | ||
| 733 | |||
| 734 | 42 | std::shared_ptr<DistPrefetchState> state; | |
| 735 | { | ||
| 736 |
1/2✓ Branch 1 taken 42 times.
✗ Branch 2 not taken.
|
42 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 737 |
1/2✓ Branch 1 taken 42 times.
✗ Branch 2 not taken.
|
42 | auto it = prefetch_states_.find(prefetch_id); |
| 738 |
2/2✓ Branch 2 taken 4 times.
✓ Branch 3 taken 38 times.
|
42 | if (it == prefetch_states_.end()) { |
| 739 |
4/8✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
|
4 | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; |
| 740 | 4 | return false; | |
| 741 | } | ||
| 742 | 38 | state = it->second; | |
| 743 |
2/2✓ Branch 1 taken 38 times.
✓ Branch 2 taken 4 times.
|
42 | } |
| 744 | |||
| 745 | // Ensure all child RPCs are completed before consuming payloads. | ||
| 746 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | WaitForPrefetch(prefetch_id); |
| 747 | |||
| 748 | 38 | values->clear(); | |
| 749 |
1/2✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
|
38 | values->resize(state->total_keys); |
| 750 | |||
| 751 | 38 | bool ok_all = true; | |
| 752 |
2/2✓ Branch 6 taken 76 times.
✓ Branch 7 taken 38 times.
|
114 | for (const auto& shard_state : state->shard_states) { |
| 753 | 76 | auto* client = clients_[shard_state.client_index].get(); | |
| 754 | 76 | size_t shard_offset = 0; | |
| 755 |
2/2✓ Branch 1 taken 104 times.
✓ Branch 2 taken 76 times.
|
180 | for (size_t i = 0; i < shard_state.child_prefetch_ids.size(); ++i) { |
| 756 | 104 | std::vector<std::vector<float>> chunk_values; | |
| 757 |
2/4✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 104 times.
|
104 | if (!client->GetPrefetchResult( |
| 758 | 104 | shard_state.child_prefetch_ids[i], &chunk_values)) { | |
| 759 | ✗ | ok_all = false; | |
| 760 | ✗ | break; | |
| 761 | } | ||
| 762 | const int expected = | ||
| 763 | 104 | (i < shard_state.chunk_sizes.size() | |
| 764 |
1/2✓ Branch 0 taken 104 times.
✗ Branch 1 not taken.
|
104 | ? shard_state.chunk_sizes[i] |
| 765 | 104 | : -1); | |
| 766 |
3/6✓ Branch 0 taken 104 times.
✗ Branch 1 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 104 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 104 times.
|
104 | if (expected >= 0 && static_cast<int>(chunk_values.size()) != expected) { |
| 767 | ✗ | LOG(ERROR) << "Prefetch chunk size mismatch: got " | |
| 768 | ✗ | << chunk_values.size() << ", expected " << expected; | |
| 769 | ✗ | ok_all = false; | |
| 770 | ✗ | break; | |
| 771 | } | ||
| 772 |
2/2✓ Branch 5 taken 424 times.
✓ Branch 6 taken 104 times.
|
528 | for (const auto& row : chunk_values) { |
| 773 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 424 times.
|
424 | if (shard_offset >= shard_state.original_indices.size()) { |
| 774 | ✗ | LOG(ERROR) << "Prefetch result overflow in shard " | |
| 775 | ✗ | << shard_state.shard_id; | |
| 776 | ✗ | ok_all = false; | |
| 777 | ✗ | break; | |
| 778 | } | ||
| 779 |
1/2✓ Branch 3 taken 424 times.
✗ Branch 4 not taken.
|
424 | (*values)[shard_state.original_indices[shard_offset++]] = row; |
| 780 | } | ||
| 781 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 104 times.
|
104 | if (!ok_all) { |
| 782 | ✗ | break; | |
| 783 | } | ||
| 784 |
1/2✓ Branch 1 taken 104 times.
✗ Branch 2 not taken.
|
104 | } |
| 785 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 76 times.
|
76 | if (!ok_all) { |
| 786 | ✗ | break; | |
| 787 | } | ||
| 788 | } | ||
| 789 | |||
| 790 | { | ||
| 791 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 792 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | prefetch_states_.erase(prefetch_id); |
| 793 | 38 | } | |
| 794 | 38 | return ok_all; | |
| 795 | 42 | } | |
| 796 | |||
| 797 | 6 | bool DistributedGRPCParameterClient::GetPrefetchResultFlat( | |
| 798 | uint64_t prefetch_id, | ||
| 799 | std::vector<float>* values, | ||
| 800 | int64_t* num_rows, | ||
| 801 | int64_t embedding_dim) { | ||
| 802 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
|
6 | if (values == nullptr || num_rows == nullptr) { |
| 803 | ✗ | LOG(ERROR) << "GetPrefetchResultFlat output pointer is null"; | |
| 804 | ✗ | return false; | |
| 805 | } | ||
| 806 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
|
6 | if (embedding_dim <= 0) { |
| 807 | ✗ | LOG(ERROR) << "GetPrefetchResultFlat invalid embedding_dim: " | |
| 808 | ✗ | << embedding_dim; | |
| 809 | ✗ | return false; | |
| 810 | } | ||
| 811 | |||
| 812 | 6 | std::vector<std::vector<float>> merged_values; | |
| 813 |
3/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✓ Branch 4 taken 4 times.
|
6 | if (!GetPrefetchResult(prefetch_id, &merged_values)) { |
| 814 | 2 | return false; | |
| 815 | } | ||
| 816 | |||
| 817 | 4 | *num_rows = static_cast<int64_t>(merged_values.size()); | |
| 818 | 4 | values->assign( | |
| 819 | 4 | static_cast<size_t>(*num_rows) * static_cast<size_t>(embedding_dim), | |
| 820 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | 0.0f); |
| 821 |
2/2✓ Branch 1 taken 24 times.
✓ Branch 2 taken 4 times.
|
28 | for (size_t i = 0; i < merged_values.size(); ++i) { |
| 822 | 24 | const auto& row = merged_values[i]; | |
| 823 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
|
24 | if (row.empty()) { |
| 824 | ✗ | continue; | |
| 825 | } | ||
| 826 | const int64_t copy_d = | ||
| 827 | 24 | std::min<int64_t>(embedding_dim, static_cast<int64_t>(row.size())); | |
| 828 | 24 | std::memcpy(values->data() + i * static_cast<size_t>(embedding_dim), | |
| 829 | 24 | row.data(), | |
| 830 | 24 | static_cast<size_t>(copy_d) * sizeof(float)); | |
| 831 | } | ||
| 832 | 4 | return true; | |
| 833 | 6 | } | |
| 834 | |||
| 835 | } // namespace recstore | ||
| 836 |