ps/brpc/dist_brpc_ps_client.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "dist_brpc_ps_client.h" | ||
| 2 | |||
| 3 | #include <algorithm> | ||
| 4 | #include <cstring> | ||
| 5 | #include <fstream> | ||
| 6 | #include <future> | ||
| 7 | #include <thread> | ||
| 8 | |||
| 9 | #include "base/factory.h" | ||
| 10 | #include "base/flatc.h" | ||
| 11 | #include "base/log.h" | ||
| 12 | #include "base/timer.h" | ||
| 13 | |||
| 14 | #ifdef ENABLE_PERF_REPORT | ||
| 15 | # include <chrono> | ||
| 16 | # include "base/report/report_client.h" | ||
| 17 | #endif | ||
| 18 | |||
| 19 | using recstoreps_brpc::CommandRequest; | ||
| 20 | using recstoreps_brpc::CommandResponse; | ||
| 21 | using recstoreps_brpc::GetParameterRequest; | ||
| 22 | using recstoreps_brpc::GetParameterResponse; | ||
| 23 | using recstoreps_brpc::PSCommand; | ||
| 24 | using recstoreps_brpc::PutParameterRequest; | ||
| 25 | using recstoreps_brpc::PutParameterResponse; | ||
| 26 | |||
| 27 | namespace recstore { | ||
| 28 | |||
| 29 | FACTORY_REGISTER( | ||
| 30 | BasePSClient, distributed_brpc, DistributedBRPCParameterClient, json); | ||
| 31 | |||
| 32 | 10 | DistributedBRPCParameterClient::DistributedBRPCParameterClient(json config) | |
| 33 |
1/2✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
|
10 | : BasePSClient(config) { |
| 34 | 10 | json client_config; | |
| 35 | |||
| 36 |
2/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
|
10 | if (config.contains("distributed_client")) { |
| 37 |
2/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
|
20 | LOG(INFO) << "Detected recstore config format, extracting " |
| 38 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | "distributed_client section"; |
| 39 |
2/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
|
10 | client_config = config["distributed_client"]; |
| 40 | } else { | ||
| 41 | ✗ | LOG(FATAL) | |
| 42 | << "Invalid config format. Expected either recstore config with " | ||
| 43 | "'distributed_client' section " | ||
| 44 | ✗ | << "or direct client config with 'servers' and 'num_shards' fields"; | |
| 45 | } | ||
| 46 | |||
| 47 | // 验证必要字段 | ||
| 48 |
3/6✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
|
20 | if (!client_config.contains("servers") || |
| 49 |
2/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 10 times.
|
10 | !client_config["servers"].is_array()) { |
| 50 | ✗ | LOG(FATAL) | |
| 51 | ✗ | << "Missing or invalid 'servers' field in distributed client config"; | |
| 52 | } | ||
| 53 | |||
| 54 |
3/6✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
|
20 | if (!client_config.contains("num_shards") || |
| 55 |
2/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 10 times.
|
10 | !client_config["num_shards"].is_number_integer()) { |
| 56 | ✗ | LOG(FATAL) | |
| 57 | ✗ | << "Missing or invalid 'num_shards' field in distributed client config"; | |
| 58 | } | ||
| 59 | |||
| 60 |
2/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
|
10 | num_shards_ = client_config["num_shards"].get<int>(); |
| 61 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | max_keys_per_request_ = client_config.value("max_keys_per_request", 500); |
| 62 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | hash_method_ = client_config.value("hash_method", "city_hash"); |
| 63 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
|
10 | if (max_keys_per_request_ <= 0) { |
| 64 | ✗ | LOG(FATAL) << "Invalid max_keys_per_request: " << max_keys_per_request_ | |
| 65 | ✗ | << ", must be > 0"; | |
| 66 | } | ||
| 67 | |||
| 68 | // 解析服务器配置 | ||
| 69 |
2/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
|
10 | auto servers = client_config["servers"]; |
| 70 |
1/2✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
|
10 | server_configs_.reserve(servers.size()); |
| 71 | |||
| 72 |
2/2✓ Branch 1 taken 20 times.
✓ Branch 2 taken 10 times.
|
30 | for (size_t i = 0; i < servers.size(); ++i) { |
| 73 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | const auto& server = servers[i]; |
| 74 |
5/10✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 20 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 20 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 20 times.
|
40 | if (!server.contains("host") || !server.contains("port") || |
| 75 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
|
20 | !server.contains("shard")) { |
| 76 | ✗ | LOG(FATAL) << "Server config " << i | |
| 77 | ✗ | << " missing required fields (host, port, shard)"; | |
| 78 | } | ||
| 79 | |||
| 80 | 20 | ServerConfig cfg; | |
| 81 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
|
20 | cfg.host = server["host"].get<std::string>(); |
| 82 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
|
20 | cfg.port = server["port"].get<int>(); |
| 83 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
|
20 | cfg.shard = server["shard"].get<int>(); |
| 84 | |||
| 85 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | server_configs_.push_back(cfg); |
| 86 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | shard_to_client_index_[cfg.shard] = i; |
| 87 | 20 | } | |
| 88 | |||
| 89 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 10 times.
|
10 | if (server_configs_.size() != static_cast<size_t>(num_shards_)) { |
| 90 | ✗ | LOG(WARNING) << "Number of servers (" << server_configs_.size() | |
| 91 | ✗ | << ") doesn't match num_shards (" << num_shards_ << ")"; | |
| 92 | } | ||
| 93 | |||
| 94 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | partitioned_key_buffer_.resize(num_shards_); |
| 95 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | key_index_mapping_.resize(num_shards_); |
| 96 | |||
| 97 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | InitializeClients(); |
| 98 | |||
| 99 |
3/6✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 8 not taken.
|
20 | LOG(INFO) << "Initialized DistributedBRPCParameterClient with " << num_shards_ |
| 100 |
3/6✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 8 not taken.
|
10 | << " shards, hash method: " << hash_method_; |
| 101 | 10 | } | |
| 102 | |||
| 103 | 14 | DistributedBRPCParameterClient::~DistributedBRPCParameterClient() {} | |
| 104 | |||
| 105 | 10 | void DistributedBRPCParameterClient::InitializeClients() { | |
| 106 | 10 | clients_.clear(); | |
| 107 | 10 | clients_.reserve(server_configs_.size()); | |
| 108 | |||
| 109 |
2/2✓ Branch 5 taken 20 times.
✓ Branch 6 taken 10 times.
|
30 | for (const auto& server_config : server_configs_) { |
| 110 | // 为每个服务器创建独立的配置 | ||
| 111 | 20 | json client_config = {{"host", server_config.host}, | |
| 112 | 20 | {"port", server_config.port}, | |
| 113 |
8/16✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
✓ Branch 14 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 17 taken 20 times.
✗ Branch 18 not taken.
✓ Branch 21 taken 20 times.
✗ Branch 22 not taken.
✓ Branch 24 taken 20 times.
✗ Branch 25 not taken.
|
260 | {"shard", server_config.shard}}; |
| 114 | |||
| 115 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | auto client = std::make_unique<BRPCParameterClient>(client_config); |
| 116 |
1/2✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
|
20 | clients_.push_back(std::move(client)); |
| 117 | |||
| 118 |
3/6✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
|
40 | LOG(INFO) << "Created bRPC client for shard " << server_config.shard |
| 119 |
5/10✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 20 times.
✗ Branch 14 not taken.
|
20 | << " at " << server_config.host << ":" << server_config.port; |
| 120 | 20 | } | |
| 121 | 10 | } | |
| 122 | |||
| 123 | 448 | int DistributedBRPCParameterClient::GetShardId(uint64_t key) const { | |
| 124 |
1/2✓ Branch 1 taken 448 times.
✗ Branch 2 not taken.
|
448 | if (hash_method_ == "city_hash") { |
| 125 | 448 | return GetHash(key) % num_shards_; | |
| 126 | ✗ | } else if (hash_method_ == "simple_mod") { | |
| 127 | ✗ | return key % num_shards_; | |
| 128 | } else { | ||
| 129 | ✗ | LOG(ERROR) << "Unknown hash method: " << hash_method_ | |
| 130 | ✗ | << ", using city_hash"; | |
| 131 | ✗ | return GetHash(key) % num_shards_; | |
| 132 | } | ||
| 133 | } | ||
| 134 | |||
| 135 | 20 | void DistributedBRPCParameterClient::PartitionKeys( | |
| 136 | const base::ConstArray<uint64_t>& keys, | ||
| 137 | std::vector<std::vector<uint64_t>>& partitioned_keys) const { | ||
| 138 |
2/2✓ Branch 5 taken 40 times.
✓ Branch 6 taken 20 times.
|
60 | for (auto& partition : partitioned_key_buffer_) { |
| 139 | 40 | partition.clear(); | |
| 140 | } | ||
| 141 |
2/2✓ Branch 5 taken 40 times.
✓ Branch 6 taken 20 times.
|
60 | for (auto& mapping : key_index_mapping_) { |
| 142 | 40 | mapping.clear(); | |
| 143 | } | ||
| 144 | |||
| 145 |
2/2✓ Branch 1 taken 448 times.
✓ Branch 2 taken 20 times.
|
468 | for (size_t i = 0; i < keys.Size(); ++i) { |
| 146 | 448 | uint64_t key = keys[i]; | |
| 147 |
1/2✓ Branch 1 taken 448 times.
✗ Branch 2 not taken.
|
448 | int shard_id = GetShardId(key); |
| 148 | |||
| 149 |
1/2✓ Branch 2 taken 448 times.
✗ Branch 3 not taken.
|
448 | partitioned_key_buffer_[shard_id].push_back(key); |
| 150 |
1/2✓ Branch 2 taken 448 times.
✗ Branch 3 not taken.
|
448 | key_index_mapping_[shard_id].push_back(i); |
| 151 | } | ||
| 152 | |||
| 153 | 20 | partitioned_keys = partitioned_key_buffer_; | |
| 154 | 20 | } | |
| 155 | |||
| 156 | 14 | bool DistributedBRPCParameterClient::GetParameter( | |
| 157 | const base::ConstArray<uint64_t>& keys, | ||
| 158 | std::vector<std::vector<float>>* values) { | ||
| 159 | #ifdef ENABLE_PERF_REPORT | ||
| 160 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 161 | #endif | ||
| 162 | |||
| 163 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 14 times.
|
14 | if (keys.Size() == 0) { |
| 164 | ✗ | values->clear(); | |
| 165 | ✗ | return true; | |
| 166 | } | ||
| 167 | |||
| 168 |
2/4✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 14 times.
✗ Branch 6 not taken.
|
28 | xmh::Timer timer("DistributedBRPCParameterClient::GetParameter"); |
| 169 | |||
| 170 | 14 | std::vector<std::vector<uint64_t>> partitioned_keys; | |
| 171 |
1/2✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
|
14 | PartitionKeys(keys, partitioned_keys); |
| 172 | |||
| 173 | 14 | std::vector<std::future<int>> futures; | |
| 174 |
1/2✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
|
14 | std::vector<std::vector<std::vector<float>>> partitioned_results(num_shards_); |
| 175 | |||
| 176 |
2/2✓ Branch 0 taken 28 times.
✓ Branch 1 taken 14 times.
|
42 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 177 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 28 times.
|
28 | if (partitioned_keys[shard_id].empty()) { |
| 178 | ✗ | continue; | |
| 179 | } | ||
| 180 | |||
| 181 |
1/2✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
|
28 | auto it = shard_to_client_index_.find(shard_id); |
| 182 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 28 times.
|
28 | if (it == shard_to_client_index_.end()) { |
| 183 | ✗ | LOG(ERROR) << "No client found for shard " << shard_id; | |
| 184 | ✗ | return false; | |
| 185 | } | ||
| 186 | |||
| 187 | 28 | int client_index = it->second; | |
| 188 | 28 | auto* client = clients_[client_index].get(); | |
| 189 | |||
| 190 | // 异步请求 | ||
| 191 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | futures.push_back(std::async( |
| 192 | 56 | std::launch::async, [=, &partitioned_keys, &partitioned_results]() { | |
| 193 | 28 | const auto& shard_keys_vec = partitioned_keys[shard_id]; | |
| 194 | 28 | auto& shard_result_vec = partitioned_results[shard_id]; | |
| 195 | 28 | shard_result_vec.clear(); | |
| 196 | 28 | shard_result_vec.reserve(shard_keys_vec.size()); | |
| 197 | |||
| 198 |
2/2✓ Branch 1 taken 30 times.
✓ Branch 2 taken 28 times.
|
58 | for (size_t start = 0; start < shard_keys_vec.size(); |
| 199 | 30 | start += static_cast<size_t>(max_keys_per_request_)) { | |
| 200 | size_t end = | ||
| 201 | 60 | std::min(start + static_cast<size_t>(max_keys_per_request_), | |
| 202 | 30 | shard_keys_vec.size()); | |
| 203 | base::ConstArray<uint64_t> shard_chunk( | ||
| 204 | 30 | shard_keys_vec.data() + start, static_cast<int>(end - start)); | |
| 205 | 30 | std::vector<std::vector<float>> chunk_result; | |
| 206 |
2/4✓ Branch 1 taken 30 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 30 times.
|
30 | if (!client->GetParameter(shard_chunk, &chunk_result)) { |
| 207 | ✗ | return 0; | |
| 208 | } | ||
| 209 |
1/2✓ Branch 5 taken 30 times.
✗ Branch 6 not taken.
|
30 | shard_result_vec.insert(shard_result_vec.end(), |
| 210 | chunk_result.begin(), | ||
| 211 | chunk_result.end()); | ||
| 212 |
1/2✓ Branch 1 taken 30 times.
✗ Branch 2 not taken.
|
30 | } |
| 213 | 28 | return 1; | |
| 214 | })); | ||
| 215 | } | ||
| 216 | |||
| 217 | // 等待所有请求完成 | ||
| 218 |
2/2✓ Branch 5 taken 28 times.
✓ Branch 6 taken 14 times.
|
42 | for (auto& future : futures) { |
| 219 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 28 times.
|
28 | if (!future.get()) { |
| 220 | ✗ | LOG(ERROR) << "Failed to get parameters from one of the shards"; | |
| 221 | ✗ | return false; | |
| 222 | } | ||
| 223 | } | ||
| 224 | |||
| 225 | // 合并结果 | ||
| 226 |
1/2✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
|
14 | MergeResults(keys, partitioned_results, values); |
| 227 | |||
| 228 | #ifdef ENABLE_PERF_REPORT | ||
| 229 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 230 | auto duration = | ||
| 231 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 232 | end_time - start_time) | ||
| 233 | .count(); | ||
| 234 | double start_us = | ||
| 235 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 236 | start_time.time_since_epoch()) | ||
| 237 | .count(); | ||
| 238 | FlameGraphData fg_data = { | ||
| 239 | "dist_client::GetParameter", | ||
| 240 | start_us, | ||
| 241 | 1, // level | ||
| 242 | static_cast<double>(duration), | ||
| 243 | static_cast<double>(duration)}; | ||
| 244 | std::string unique_id = | ||
| 245 | "embread_debug" + std::to_string(recstore::g_trace_id); | ||
| 246 | report_flame_graph("emb_read_flame_map", unique_id.c_str(), fg_data); | ||
| 247 | #endif | ||
| 248 | |||
| 249 | 14 | return true; | |
| 250 | 14 | } | |
| 251 | |||
| 252 | 14 | void DistributedBRPCParameterClient::MergeResults( | |
| 253 | const base::ConstArray<uint64_t>& keys, | ||
| 254 | const std::vector<std::vector<std::vector<float>>>& partitioned_results, | ||
| 255 | std::vector<std::vector<float>>* values) const { | ||
| 256 | 14 | values->clear(); | |
| 257 | 14 | values->resize(keys.Size()); | |
| 258 | |||
| 259 | // 重建 key -> index 映射 | ||
| 260 |
2/2✓ Branch 0 taken 28 times.
✓ Branch 1 taken 14 times.
|
42 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 261 |
2/2✓ Branch 2 taken 236 times.
✓ Branch 3 taken 28 times.
|
264 | for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) { |
| 262 | 236 | size_t original_index = key_index_mapping_[shard_id][i]; | |
| 263 |
1/2✓ Branch 2 taken 236 times.
✗ Branch 3 not taken.
|
236 | if (i < partitioned_results[shard_id].size()) { |
| 264 | 236 | (*values)[original_index] = partitioned_results[shard_id][i]; | |
| 265 | } | ||
| 266 | } | ||
| 267 | } | ||
| 268 | 14 | } | |
| 269 | |||
| 270 | ✗ | void DistributedBRPCParameterClient::MergeResultsToArray( | |
| 271 | const base::ConstArray<uint64_t>& keys, | ||
| 272 | const std::vector<std::vector<std::vector<float>>>& partitioned_results, | ||
| 273 | float* values) const { | ||
| 274 | ✗ | int emb_dim = 0; | |
| 275 | ✗ | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { | |
| 276 | ✗ | if (!partitioned_results[shard_id].empty() && | |
| 277 | ✗ | !partitioned_results[shard_id][0].empty()) { | |
| 278 | ✗ | emb_dim = partitioned_results[shard_id][0].size(); | |
| 279 | ✗ | break; | |
| 280 | } | ||
| 281 | } | ||
| 282 | |||
| 283 | ✗ | if (emb_dim == 0) { | |
| 284 | ✗ | LOG(WARNING) << "No valid embeddings found"; | |
| 285 | ✗ | return; | |
| 286 | } | ||
| 287 | |||
| 288 | // 合并结果到连续内存 | ||
| 289 | ✗ | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { | |
| 290 | ✗ | for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) { | |
| 291 | ✗ | size_t original_index = key_index_mapping_[shard_id][i]; | |
| 292 | ✗ | if (i < partitioned_results[shard_id].size()) { | |
| 293 | ✗ | const auto& embedding = partitioned_results[shard_id][i]; | |
| 294 | ✗ | std::copy(embedding.begin(), | |
| 295 | embedding.end(), | ||
| 296 | ✗ | values + original_index * emb_dim); | |
| 297 | } | ||
| 298 | } | ||
| 299 | } | ||
| 300 | } | ||
| 301 | |||
| 302 | // 实现 BasePSClient 接口 | ||
| 303 | ✗ | int DistributedBRPCParameterClient::GetParameter( | |
| 304 | const base::ConstArray<uint64_t>& keys, float* values) { | ||
| 305 | ✗ | std::vector<std::vector<float>> result_vectors; | |
| 306 | ✗ | bool success = GetParameter(keys, &result_vectors); | |
| 307 | |||
| 308 | ✗ | if (!success) { | |
| 309 | ✗ | return -1; | |
| 310 | } | ||
| 311 | |||
| 312 | ✗ | if (keys.Size() == 0) { | |
| 313 | ✗ | return 0; | |
| 314 | } | ||
| 315 | ✗ | int emb_dim = 0; | |
| 316 | ✗ | for (const auto& row : result_vectors) { | |
| 317 | ✗ | if (!row.empty()) { | |
| 318 | ✗ | emb_dim = static_cast<int>(row.size()); | |
| 319 | ✗ | break; | |
| 320 | } | ||
| 321 | } | ||
| 322 | ✗ | if (emb_dim == 0) { | |
| 323 | ✗ | LOG(WARNING) << "No valid embeddings found"; | |
| 324 | ✗ | return 0; | |
| 325 | } | ||
| 326 | |||
| 327 | ✗ | for (size_t i = 0; i < result_vectors.size(); ++i) { | |
| 328 | ✗ | const auto& row = result_vectors[i]; | |
| 329 | ✗ | if (row.empty()) { | |
| 330 | ✗ | continue; | |
| 331 | } | ||
| 332 | ✗ | std::copy(row.begin(), row.end(), values + i * emb_dim); | |
| 333 | } | ||
| 334 | ✗ | return 0; | |
| 335 | ✗ | } | |
| 336 | |||
| 337 | ✗ | int DistributedBRPCParameterClient::AsyncGetParameter( | |
| 338 | const base::ConstArray<uint64_t>& keys, float* values) { | ||
| 339 | ✗ | return GetParameter(keys, values); | |
| 340 | } | ||
| 341 | |||
| 342 | 6 | int DistributedBRPCParameterClient::PutParameter( | |
| 343 | const base::ConstArray<uint64_t>& keys, | ||
| 344 | const std::vector<std::vector<float>>& values) { | ||
| 345 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
|
6 | if (keys.Size() != values.size()) { |
| 346 | ✗ | LOG(ERROR) << "Keys and values size mismatch: " << keys.Size() << " vs " | |
| 347 | ✗ | << values.size(); | |
| 348 | ✗ | return -1; | |
| 349 | } | ||
| 350 | |||
| 351 | 6 | std::vector<std::vector<uint64_t>> partitioned_keys; | |
| 352 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | PartitionKeys(keys, partitioned_keys); |
| 353 | |||
| 354 |
1/2✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | std::vector<std::vector<std::vector<float>>> partitioned_values(num_shards_); |
| 355 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
|
18 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 356 |
2/2✓ Branch 2 taken 212 times.
✓ Branch 3 taken 12 times.
|
224 | for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) { |
| 357 | 212 | size_t original_index = key_index_mapping_[shard_id][i]; | |
| 358 |
1/2✓ Branch 3 taken 212 times.
✗ Branch 4 not taken.
|
212 | partitioned_values[shard_id].push_back(values[original_index]); |
| 359 | } | ||
| 360 | } | ||
| 361 | |||
| 362 | // 并发 put 到各个分片 | ||
| 363 | 6 | std::vector<std::future<int>> futures; | |
| 364 | |||
| 365 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
|
18 | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { |
| 366 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
|
12 | if (partitioned_keys[shard_id].empty()) { |
| 367 | ✗ | continue; | |
| 368 | } | ||
| 369 | |||
| 370 |
1/2✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
|
12 | auto it = shard_to_client_index_.find(shard_id); |
| 371 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
|
12 | if (it == shard_to_client_index_.end()) { |
| 372 | ✗ | LOG(ERROR) << "No client found for shard " << shard_id; | |
| 373 | ✗ | return -1; | |
| 374 | } | ||
| 375 | |||
| 376 | 12 | int client_index = it->second; | |
| 377 | 12 | auto* client = clients_[client_index].get(); | |
| 378 | |||
| 379 |
2/4✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
|
12 | futures.push_back(std::async( |
| 380 | 24 | std::launch::async, [=, &partitioned_keys, &partitioned_values]() { | |
| 381 | 12 | const auto& shard_keys_vec = partitioned_keys[shard_id]; | |
| 382 | 12 | const auto& shard_vals_vec = partitioned_values[shard_id]; | |
| 383 |
2/2✓ Branch 1 taken 14 times.
✓ Branch 2 taken 12 times.
|
26 | for (size_t start = 0; start < shard_keys_vec.size(); |
| 384 | 14 | start += static_cast<size_t>(max_keys_per_request_)) { | |
| 385 | size_t end = | ||
| 386 | 28 | std::min(start + static_cast<size_t>(max_keys_per_request_), | |
| 387 | 14 | shard_keys_vec.size()); | |
| 388 | base::ConstArray<uint64_t> shard_chunk( | ||
| 389 | 14 | shard_keys_vec.data() + start, static_cast<int>(end - start)); | |
| 390 | std::vector<std::vector<float>> value_chunk( | ||
| 391 |
1/2✓ Branch 6 taken 14 times.
✗ Branch 7 not taken.
|
14 | shard_vals_vec.begin() + start, shard_vals_vec.begin() + end); |
| 392 |
2/4✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | if (client->PutParameter(shard_chunk, value_chunk) != 1) { |
| 393 | ✗ | return 0; | |
| 394 | } | ||
| 395 |
1/2✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
|
14 | } |
| 396 | 12 | return 1; | |
| 397 | })); | ||
| 398 | } | ||
| 399 | |||
| 400 | // 等待所有请求完成 | ||
| 401 |
2/2✓ Branch 5 taken 12 times.
✓ Branch 6 taken 6 times.
|
18 | for (auto& future : futures) { |
| 402 |
2/4✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | if (future.get() != 1) { |
| 403 | ✗ | LOG(ERROR) << "Failed to put parameters to one of the shards"; | |
| 404 | ✗ | return -1; | |
| 405 | } | ||
| 406 | } | ||
| 407 | |||
| 408 | 6 | return 0; | |
| 409 | 6 | } | |
| 410 | |||
| 411 | ✗ | void DistributedBRPCParameterClient::Command(PSCommand command) { | |
| 412 | ✗ | std::vector<std::future<void>> futures; | |
| 413 | |||
| 414 | ✗ | for (auto& client : clients_) { | |
| 415 | ✗ | futures.push_back(std::async(std::launch::async, [&client, command]() { | |
| 416 | ✗ | client->Command(command); | |
| 417 | ✗ | })); | |
| 418 | } | ||
| 419 | |||
| 420 | ✗ | for (auto& future : futures) { | |
| 421 | ✗ | future.wait(); | |
| 422 | } | ||
| 423 | ✗ | } | |
| 424 | |||
| 425 | 10 | bool DistributedBRPCParameterClient::ClearPS() { | |
| 426 | 10 | std::vector<std::future<bool>> futures; | |
| 427 | |||
| 428 |
2/2✓ Branch 4 taken 20 times.
✓ Branch 5 taken 10 times.
|
30 | for (auto& client : clients_) { |
| 429 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
|
20 | futures.push_back(std::async(std::launch::async, [&client]() { |
| 430 | 20 | return client->ClearPS(); | |
| 431 | })); | ||
| 432 | } | ||
| 433 | |||
| 434 | 10 | bool all_success = true; | |
| 435 |
2/2✓ Branch 5 taken 20 times.
✓ Branch 6 taken 10 times.
|
30 | for (auto& future : futures) { |
| 436 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
|
20 | if (!future.get()) { |
| 437 | ✗ | all_success = false; | |
| 438 | } | ||
| 439 | } | ||
| 440 | |||
| 441 | 10 | return all_success; | |
| 442 | 10 | } | |
| 443 | |||
| 444 | 2 | bool DistributedBRPCParameterClient::LoadFakeData(int64_t n) { | |
| 445 | 2 | std::vector<std::future<bool>> futures; | |
| 446 |
2/2✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
|
6 | for (auto& client : clients_) { |
| 447 | 4 | BRPCParameterClient* raw = client.get(); | |
| 448 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
|
4 | futures.push_back(std::async(std::launch::async, [raw, n]() { |
| 449 | 4 | return raw->LoadFakeData(n); | |
| 450 | })); | ||
| 451 | } | ||
| 452 | 2 | bool all_success = true; | |
| 453 |
2/2✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
|
6 | for (auto& future : futures) { |
| 454 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
|
4 | if (!future.get()) { |
| 455 | ✗ | all_success = false; | |
| 456 | } | ||
| 457 | } | ||
| 458 | 2 | return all_success; | |
| 459 | 2 | } | |
| 460 | |||
| 461 | 2 | bool DistributedBRPCParameterClient::DumpFakeData(int64_t n) { | |
| 462 | 2 | std::vector<std::future<bool>> futures; | |
| 463 |
2/2✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
|
6 | for (auto& client : clients_) { |
| 464 | 4 | BRPCParameterClient* raw = client.get(); | |
| 465 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
|
4 | futures.push_back(std::async(std::launch::async, [raw, n]() { |
| 466 | 4 | return raw->DumpFakeData(n); | |
| 467 | })); | ||
| 468 | } | ||
| 469 | 2 | bool all_success = true; | |
| 470 |
2/2✓ Branch 5 taken 4 times.
✓ Branch 6 taken 2 times.
|
6 | for (auto& future : futures) { |
| 471 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
|
4 | if (!future.get()) { |
| 472 | ✗ | all_success = false; | |
| 473 | } | ||
| 474 | } | ||
| 475 | 2 | return all_success; | |
| 476 | 2 | } | |
| 477 | ✗ | bool DistributedBRPCParameterClient::LoadCkpt( | |
| 478 | const std::vector<std::string>& model_config_path, | ||
| 479 | const std::vector<std::string>& emb_file_path) { | ||
| 480 | ✗ | std::vector<std::future<bool>> futures; | |
| 481 | |||
| 482 | ✗ | for (auto& client : clients_) { | |
| 483 | ✗ | futures.push_back(std::async( | |
| 484 | ✗ | std::launch::async, [&client, &model_config_path, &emb_file_path]() { | |
| 485 | ✗ | return client->LoadCkpt(model_config_path, emb_file_path); | |
| 486 | })); | ||
| 487 | } | ||
| 488 | |||
| 489 | ✗ | bool all_success = true; | |
| 490 | ✗ | for (auto& future : futures) { | |
| 491 | ✗ | if (!future.get()) { | |
| 492 | ✗ | all_success = false; | |
| 493 | } | ||
| 494 | } | ||
| 495 | |||
| 496 | ✗ | return all_success; | |
| 497 | ✗ | } | |
| 498 | |||
| 499 | ✗ | int DistributedBRPCParameterClient::UpdateParameter( | |
| 500 | const std::string& table_name, | ||
| 501 | const base::ConstArray<uint64_t>& keys, | ||
| 502 | const std::vector<std::vector<float>>* grads) { | ||
| 503 | ✗ | if (grads == nullptr) { | |
| 504 | ✗ | LOG(ERROR) << "UpdateParameter grads pointer is null"; | |
| 505 | ✗ | return -1; | |
| 506 | } | ||
| 507 | ✗ | if (keys.Size() != grads->size()) { | |
| 508 | ✗ | LOG(ERROR) << "UpdateParameter keys/grads size mismatch: " << keys.Size() | |
| 509 | ✗ | << " vs " << grads->size(); | |
| 510 | ✗ | return -1; | |
| 511 | } | ||
| 512 | ✗ | if (keys.Size() == 0) { | |
| 513 | ✗ | return 0; | |
| 514 | } | ||
| 515 | |||
| 516 | ✗ | std::vector<std::vector<uint64_t>> partitioned_keys; | |
| 517 | ✗ | PartitionKeys(keys, partitioned_keys); | |
| 518 | |||
| 519 | ✗ | std::vector<std::vector<std::vector<float>>> partitioned_grads(num_shards_); | |
| 520 | ✗ | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { | |
| 521 | ✗ | for (size_t i = 0; i < key_index_mapping_[shard_id].size(); ++i) { | |
| 522 | ✗ | size_t original_index = key_index_mapping_[shard_id][i]; | |
| 523 | ✗ | partitioned_grads[shard_id].push_back((*grads)[original_index]); | |
| 524 | } | ||
| 525 | } | ||
| 526 | |||
| 527 | ✗ | std::vector<std::future<int>> futures; | |
| 528 | ✗ | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { | |
| 529 | ✗ | if (partitioned_keys[shard_id].empty()) { | |
| 530 | ✗ | continue; | |
| 531 | } | ||
| 532 | |||
| 533 | ✗ | auto it = shard_to_client_index_.find(shard_id); | |
| 534 | ✗ | if (it == shard_to_client_index_.end()) { | |
| 535 | ✗ | LOG(ERROR) << "No client found for shard " << shard_id; | |
| 536 | ✗ | return -1; | |
| 537 | } | ||
| 538 | ✗ | int client_index = it->second; | |
| 539 | ✗ | auto* client = clients_[client_index].get(); | |
| 540 | |||
| 541 | ✗ | futures.push_back(std::async( | |
| 542 | ✗ | std::launch::async, [=, &partitioned_keys, &partitioned_grads]() { | |
| 543 | ✗ | const auto& shard_keys_vec = partitioned_keys[shard_id]; | |
| 544 | ✗ | const auto& shard_grads_vec = partitioned_grads[shard_id]; | |
| 545 | ✗ | for (size_t start = 0; start < shard_keys_vec.size(); | |
| 546 | ✗ | start += static_cast<size_t>(max_keys_per_request_)) { | |
| 547 | size_t end = | ||
| 548 | ✗ | std::min(start + static_cast<size_t>(max_keys_per_request_), | |
| 549 | ✗ | shard_keys_vec.size()); | |
| 550 | base::ConstArray<uint64_t> shard_chunk( | ||
| 551 | ✗ | shard_keys_vec.data() + start, static_cast<int>(end - start)); | |
| 552 | std::vector<std::vector<float>> grad_chunk( | ||
| 553 | ✗ | shard_grads_vec.begin() + start, shard_grads_vec.begin() + end); | |
| 554 | ✗ | if (client->UpdateParameter(table_name, shard_chunk, &grad_chunk) != | |
| 555 | 0) { | ||
| 556 | ✗ | return -1; | |
| 557 | } | ||
| 558 | ✗ | } | |
| 559 | ✗ | return 0; | |
| 560 | })); | ||
| 561 | } | ||
| 562 | |||
| 563 | ✗ | for (auto& future : futures) { | |
| 564 | ✗ | if (future.get() != 0) { | |
| 565 | ✗ | LOG(ERROR) << "Failed to update parameters on one of the shards"; | |
| 566 | ✗ | return -1; | |
| 567 | } | ||
| 568 | } | ||
| 569 | |||
| 570 | ✗ | return 0; | |
| 571 | ✗ | } | |
| 572 | |||
| 573 | ✗ | int DistributedBRPCParameterClient::UpdateParameterFlat( | |
| 574 | const std::string& table_name, | ||
| 575 | const base::ConstArray<uint64_t>& keys, | ||
| 576 | const float* grads, | ||
| 577 | int64_t num_rows, | ||
| 578 | int64_t embedding_dim) { | ||
| 579 | ✗ | if (grads == nullptr) { | |
| 580 | ✗ | LOG(ERROR) << "UpdateParameterFlat grads pointer is null"; | |
| 581 | ✗ | return -1; | |
| 582 | } | ||
| 583 | ✗ | if (num_rows < 0 || embedding_dim <= 0) { | |
| 584 | ✗ | LOG(ERROR) << "UpdateParameterFlat invalid shape: rows=" << num_rows | |
| 585 | ✗ | << " dim=" << embedding_dim; | |
| 586 | ✗ | return -1; | |
| 587 | } | ||
| 588 | ✗ | if (keys.Size() != static_cast<size_t>(num_rows)) { | |
| 589 | ✗ | LOG(ERROR) << "UpdateParameterFlat keys/grads size mismatch: " | |
| 590 | ✗ | << keys.Size() << " vs " << num_rows; | |
| 591 | ✗ | return -1; | |
| 592 | } | ||
| 593 | |||
| 594 | ✗ | std::vector<std::vector<float>> row_grads; | |
| 595 | ✗ | row_grads.reserve(static_cast<size_t>(num_rows)); | |
| 596 | ✗ | for (int64_t i = 0; i < num_rows; ++i) { | |
| 597 | ✗ | const float* row = grads + i * embedding_dim; | |
| 598 | ✗ | row_grads.emplace_back(row, row + embedding_dim); | |
| 599 | } | ||
| 600 | ✗ | return UpdateParameter(table_name, keys, &row_grads); | |
| 601 | ✗ | } | |
| 602 | |||
| 603 | ✗ | int DistributedBRPCParameterClient::InitEmbeddingTable( | |
| 604 | const std::string& table_name, | ||
| 605 | const recstore::EmbeddingTableConfig& config) { | ||
| 606 | ✗ | std::vector<std::future<int>> futures; | |
| 607 | ✗ | for (auto& client : clients_) { | |
| 608 | ✗ | futures.push_back( | |
| 609 | ✗ | std::async(std::launch::async, [&client, &table_name, &config]() { | |
| 610 | ✗ | return client->InitEmbeddingTable(table_name, config); | |
| 611 | })); | ||
| 612 | } | ||
| 613 | |||
| 614 | ✗ | for (auto& future : futures) { | |
| 615 | ✗ | if (future.get() != 0) { | |
| 616 | ✗ | LOG(ERROR) << "InitEmbeddingTable failed on one of the shards"; | |
| 617 | ✗ | return -1; | |
| 618 | } | ||
| 619 | } | ||
| 620 | ✗ | return 0; | |
| 621 | ✗ | } | |
| 622 | |||
| 623 | // Prefetch 接口实现 | ||
| 624 | ✗ | uint64_t DistributedBRPCParameterClient::PrefetchParameter( | |
| 625 | const base::ConstArray<uint64_t>& keys) { | ||
| 626 | ✗ | auto cleanup_state = [this](const DistPrefetchState& state) { | |
| 627 | ✗ | for (const auto& shard_state : state.shard_states) { | |
| 628 | ✗ | if (shard_state.client_index < 0 || | |
| 629 | ✗ | shard_state.client_index >= static_cast<int>(clients_.size())) { | |
| 630 | ✗ | continue; | |
| 631 | } | ||
| 632 | ✗ | auto* client = clients_[shard_state.client_index].get(); | |
| 633 | ✗ | for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) { | |
| 634 | ✗ | client->WaitForPrefetch(child_prefetch_id); | |
| 635 | ✗ | std::vector<std::vector<float>> tmp; | |
| 636 | ✗ | client->GetPrefetchResult(child_prefetch_id, &tmp); | |
| 637 | ✗ | } | |
| 638 | } | ||
| 639 | ✗ | }; | |
| 640 | |||
| 641 | ✗ | if (keys.Size() == 0) { | |
| 642 | ✗ | std::lock_guard<std::mutex> lk(prefetch_mu_); | |
| 643 | ✗ | uint64_t prefetch_id = next_prefetch_id_++; | |
| 644 | ✗ | auto state = std::make_shared<DistPrefetchState>(); | |
| 645 | ✗ | state->total_keys = 0; | |
| 646 | ✗ | prefetch_states_[prefetch_id] = state; | |
| 647 | ✗ | return prefetch_id; | |
| 648 | ✗ | } | |
| 649 | |||
| 650 | ✗ | std::vector<std::vector<uint64_t>> shard_keys(num_shards_); | |
| 651 | ✗ | std::vector<std::vector<size_t>> shard_indices(num_shards_); | |
| 652 | ✗ | for (size_t i = 0; i < keys.Size(); ++i) { | |
| 653 | ✗ | const int shard_id = GetShardId(keys[i]); | |
| 654 | ✗ | shard_keys[shard_id].push_back(keys[i]); | |
| 655 | ✗ | shard_indices[shard_id].push_back(i); | |
| 656 | } | ||
| 657 | |||
| 658 | ✗ | auto state = std::make_shared<DistPrefetchState>(); | |
| 659 | ✗ | state->total_keys = keys.Size(); | |
| 660 | |||
| 661 | ✗ | for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { | |
| 662 | ✗ | if (shard_keys[shard_id].empty()) { | |
| 663 | ✗ | continue; | |
| 664 | } | ||
| 665 | |||
| 666 | ✗ | auto it = shard_to_client_index_.find(shard_id); | |
| 667 | ✗ | if (it == shard_to_client_index_.end()) { | |
| 668 | ✗ | LOG(ERROR) << "No client found for shard " << shard_id; | |
| 669 | ✗ | cleanup_state(*state); | |
| 670 | ✗ | return 0; | |
| 671 | } | ||
| 672 | |||
| 673 | ✗ | DistPrefetchShardState shard_state; | |
| 674 | ✗ | shard_state.shard_id = shard_id; | |
| 675 | ✗ | shard_state.client_index = it->second; | |
| 676 | ✗ | shard_state.original_indices = std::move(shard_indices[shard_id]); | |
| 677 | |||
| 678 | ✗ | const auto& skeys = shard_keys[shard_id]; | |
| 679 | ✗ | for (size_t start = 0; start < skeys.size(); | |
| 680 | ✗ | start += static_cast<size_t>(max_keys_per_request_)) { | |
| 681 | ✗ | size_t end = std::min( | |
| 682 | ✗ | start + static_cast<size_t>(max_keys_per_request_), skeys.size()); | |
| 683 | base::ConstArray<uint64_t> chunk( | ||
| 684 | ✗ | skeys.data() + start, static_cast<int>(end - start)); | |
| 685 | uint64_t child_prefetch_id = | ||
| 686 | ✗ | clients_[shard_state.client_index]->PrefetchParameter(chunk); | |
| 687 | ✗ | if (child_prefetch_id == 0) { | |
| 688 | ✗ | LOG(ERROR) << "PrefetchParameter failed for shard " << shard_id; | |
| 689 | ✗ | cleanup_state(*state); | |
| 690 | ✗ | return 0; | |
| 691 | } | ||
| 692 | ✗ | shard_state.child_prefetch_ids.push_back(child_prefetch_id); | |
| 693 | ✗ | shard_state.chunk_sizes.push_back(static_cast<int>(end - start)); | |
| 694 | } | ||
| 695 | ✗ | state->shard_states.push_back(std::move(shard_state)); | |
| 696 | ✗ | } | |
| 697 | |||
| 698 | ✗ | std::lock_guard<std::mutex> lk(prefetch_mu_); | |
| 699 | ✗ | uint64_t prefetch_id = next_prefetch_id_++; | |
| 700 | ✗ | prefetch_states_[prefetch_id] = std::move(state); | |
| 701 | ✗ | return prefetch_id; | |
| 702 | ✗ | } | |
| 703 | |||
| 704 | ✗ | bool DistributedBRPCParameterClient::IsPrefetchDone(uint64_t prefetch_id) { | |
| 705 | ✗ | std::shared_ptr<DistPrefetchState> state; | |
| 706 | { | ||
| 707 | ✗ | std::lock_guard<std::mutex> lk(prefetch_mu_); | |
| 708 | ✗ | auto it = prefetch_states_.find(prefetch_id); | |
| 709 | ✗ | if (it == prefetch_states_.end()) { | |
| 710 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 711 | ✗ | return false; | |
| 712 | } | ||
| 713 | ✗ | state = it->second; | |
| 714 | ✗ | } | |
| 715 | |||
| 716 | ✗ | for (const auto& shard_state : state->shard_states) { | |
| 717 | ✗ | auto* client = clients_[shard_state.client_index].get(); | |
| 718 | ✗ | for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) { | |
| 719 | ✗ | if (!client->IsPrefetchDone(child_prefetch_id)) { | |
| 720 | ✗ | return false; | |
| 721 | } | ||
| 722 | } | ||
| 723 | } | ||
| 724 | ✗ | return true; | |
| 725 | ✗ | } | |
| 726 | |||
| 727 | ✗ | void DistributedBRPCParameterClient::WaitForPrefetch(uint64_t prefetch_id) { | |
| 728 | ✗ | std::shared_ptr<DistPrefetchState> state; | |
| 729 | { | ||
| 730 | ✗ | std::lock_guard<std::mutex> lk(prefetch_mu_); | |
| 731 | ✗ | auto it = prefetch_states_.find(prefetch_id); | |
| 732 | ✗ | if (it == prefetch_states_.end()) { | |
| 733 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 734 | ✗ | return; | |
| 735 | } | ||
| 736 | ✗ | state = it->second; | |
| 737 | ✗ | } | |
| 738 | |||
| 739 | ✗ | for (const auto& shard_state : state->shard_states) { | |
| 740 | ✗ | auto* client = clients_[shard_state.client_index].get(); | |
| 741 | ✗ | for (uint64_t child_prefetch_id : shard_state.child_prefetch_ids) { | |
| 742 | ✗ | client->WaitForPrefetch(child_prefetch_id); | |
| 743 | } | ||
| 744 | } | ||
| 745 | ✗ | } | |
| 746 | |||
| 747 | ✗ | bool DistributedBRPCParameterClient::GetPrefetchResult( | |
| 748 | uint64_t prefetch_id, std::vector<std::vector<float>>* values) { | ||
| 749 | ✗ | if (values == nullptr) { | |
| 750 | ✗ | LOG(ERROR) << "GetPrefetchResult output pointer is null"; | |
| 751 | ✗ | return false; | |
| 752 | } | ||
| 753 | |||
| 754 | ✗ | std::shared_ptr<DistPrefetchState> state; | |
| 755 | { | ||
| 756 | ✗ | std::lock_guard<std::mutex> lk(prefetch_mu_); | |
| 757 | ✗ | auto it = prefetch_states_.find(prefetch_id); | |
| 758 | ✗ | if (it == prefetch_states_.end()) { | |
| 759 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 760 | ✗ | return false; | |
| 761 | } | ||
| 762 | ✗ | state = it->second; | |
| 763 | ✗ | } | |
| 764 | |||
| 765 | // Ensure all child RPCs are completed before consuming payloads. | ||
| 766 | ✗ | WaitForPrefetch(prefetch_id); | |
| 767 | |||
| 768 | ✗ | values->clear(); | |
| 769 | ✗ | values->resize(state->total_keys); | |
| 770 | |||
| 771 | ✗ | bool ok_all = true; | |
| 772 | ✗ | for (const auto& shard_state : state->shard_states) { | |
| 773 | ✗ | auto* client = clients_[shard_state.client_index].get(); | |
| 774 | ✗ | size_t shard_offset = 0; | |
| 775 | ✗ | for (size_t i = 0; i < shard_state.child_prefetch_ids.size(); ++i) { | |
| 776 | ✗ | std::vector<std::vector<float>> chunk_values; | |
| 777 | ✗ | if (!client->GetPrefetchResult( | |
| 778 | ✗ | shard_state.child_prefetch_ids[i], &chunk_values)) { | |
| 779 | ✗ | ok_all = false; | |
| 780 | ✗ | break; | |
| 781 | } | ||
| 782 | const int expected = | ||
| 783 | ✗ | (i < shard_state.chunk_sizes.size() | |
| 784 | ✗ | ? shard_state.chunk_sizes[i] | |
| 785 | ✗ | : -1); | |
| 786 | ✗ | if (expected >= 0 && static_cast<int>(chunk_values.size()) != expected) { | |
| 787 | ✗ | LOG(ERROR) << "Prefetch chunk size mismatch: got " | |
| 788 | ✗ | << chunk_values.size() << ", expected " << expected; | |
| 789 | ✗ | ok_all = false; | |
| 790 | ✗ | break; | |
| 791 | } | ||
| 792 | ✗ | for (const auto& row : chunk_values) { | |
| 793 | ✗ | if (shard_offset >= shard_state.original_indices.size()) { | |
| 794 | ✗ | LOG(ERROR) << "Prefetch result overflow in shard " | |
| 795 | ✗ | << shard_state.shard_id; | |
| 796 | ✗ | ok_all = false; | |
| 797 | ✗ | break; | |
| 798 | } | ||
| 799 | ✗ | (*values)[shard_state.original_indices[shard_offset++]] = row; | |
| 800 | } | ||
| 801 | ✗ | if (!ok_all) { | |
| 802 | ✗ | break; | |
| 803 | } | ||
| 804 | ✗ | } | |
| 805 | ✗ | if (!ok_all) { | |
| 806 | ✗ | break; | |
| 807 | } | ||
| 808 | } | ||
| 809 | |||
| 810 | { | ||
| 811 | ✗ | std::lock_guard<std::mutex> lk(prefetch_mu_); | |
| 812 | ✗ | prefetch_states_.erase(prefetch_id); | |
| 813 | ✗ | } | |
| 814 | ✗ | return ok_all; | |
| 815 | ✗ | } | |
| 816 | |||
| 817 | ✗ | bool DistributedBRPCParameterClient::GetPrefetchResultFlat( | |
| 818 | uint64_t prefetch_id, | ||
| 819 | std::vector<float>* values, | ||
| 820 | int64_t* num_rows, | ||
| 821 | int64_t embedding_dim) { | ||
| 822 | ✗ | if (values == nullptr || num_rows == nullptr) { | |
| 823 | ✗ | LOG(ERROR) << "GetPrefetchResultFlat output pointer is null"; | |
| 824 | ✗ | return false; | |
| 825 | } | ||
| 826 | ✗ | if (embedding_dim <= 0) { | |
| 827 | ✗ | LOG(ERROR) << "GetPrefetchResultFlat invalid embedding_dim: " | |
| 828 | ✗ | << embedding_dim; | |
| 829 | ✗ | return false; | |
| 830 | } | ||
| 831 | |||
| 832 | ✗ | std::vector<std::vector<float>> merged_values; | |
| 833 | ✗ | if (!GetPrefetchResult(prefetch_id, &merged_values)) { | |
| 834 | ✗ | return false; | |
| 835 | } | ||
| 836 | |||
| 837 | ✗ | *num_rows = static_cast<int64_t>(merged_values.size()); | |
| 838 | ✗ | values->assign( | |
| 839 | ✗ | static_cast<size_t>(*num_rows) * static_cast<size_t>(embedding_dim), | |
| 840 | ✗ | 0.0f); | |
| 841 | ✗ | for (size_t i = 0; i < merged_values.size(); ++i) { | |
| 842 | ✗ | const auto& row = merged_values[i]; | |
| 843 | ✗ | if (row.empty()) { | |
| 844 | ✗ | continue; | |
| 845 | } | ||
| 846 | const int64_t copy_d = | ||
| 847 | ✗ | std::min<int64_t>(embedding_dim, static_cast<int64_t>(row.size())); | |
| 848 | ✗ | std::memcpy(values->data() + i * static_cast<size_t>(embedding_dim), | |
| 849 | ✗ | row.data(), | |
| 850 | ✗ | static_cast<size_t>(copy_d) * sizeof(float)); | |
| 851 | } | ||
| 852 | ✗ | return true; | |
| 853 | ✗ | } | |
| 854 | |||
| 855 | } // namespace recstore | ||
| 856 |