ps/brpc/brpc_ps_client.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "brpc_ps_client.h" | ||
| 2 | |||
| 3 | #include <brpc/channel.h> | ||
| 4 | #include <fmt/core.h> | ||
| 5 | |||
| 6 | #include <cstdint> | ||
| 7 | #include <cstring> | ||
| 8 | #include <future> | ||
| 9 | #include <string> | ||
| 10 | #include <vector> | ||
| 11 | |||
| 12 | #include "base/array.h" | ||
| 13 | #include "base/factory.h" | ||
| 14 | #include "base/flatc.h" | ||
| 15 | #include "base/log.h" | ||
| 16 | #include "base/timer.h" | ||
| 17 | #include "ps/base/parameters.h" | ||
| 18 | #include "ps_brpc.pb.h" | ||
| 19 | |||
| 20 | #ifdef ENABLE_PERF_REPORT | ||
| 21 | # include <chrono> | ||
| 22 | # include "base/report/report_client.h" | ||
| 23 | #endif | ||
| 24 | |||
| 25 | using recstoreps_brpc::CommandRequest; | ||
| 26 | using recstoreps_brpc::CommandResponse; | ||
| 27 | using recstoreps_brpc::GetParameterRequest; | ||
| 28 | using recstoreps_brpc::GetParameterResponse; | ||
| 29 | using recstoreps_brpc::InitEmbeddingTableRequest; | ||
| 30 | using recstoreps_brpc::InitEmbeddingTableResponse; | ||
| 31 | using recstoreps_brpc::PSCommand; | ||
| 32 | using recstoreps_brpc::PutParameterRequest; | ||
| 33 | using recstoreps_brpc::PutParameterResponse; | ||
| 34 | using recstoreps_brpc::UpdateParameterRequest; | ||
| 35 | using recstoreps_brpc::UpdateParameterResponse; | ||
| 36 | |||
| 37 | namespace { | ||
| 38 | |||
| 39 | 50 | const ParameterCompressReader* ExtractGetResponseReader( | |
| 40 | const brpc::Controller& cntl, | ||
| 41 | const GetParameterResponse& response, | ||
| 42 | std::string* payload_storage, | ||
| 43 | int* payload_size) { | ||
| 44 |
1/2✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
|
50 | if (!cntl.response_attachment().empty()) { |
| 45 | 50 | payload_storage->clear(); | |
| 46 | 50 | cntl.response_attachment().copy_to(payload_storage); | |
| 47 | 50 | *payload_size = payload_storage->size(); | |
| 48 | return reinterpret_cast<const ParameterCompressReader*>( | ||
| 49 | 50 | payload_storage->data()); | |
| 50 | } | ||
| 51 | |||
| 52 | ✗ | *payload_size = response.parameter_value().size(); | |
| 53 | return reinterpret_cast<const ParameterCompressReader*>( | ||
| 54 | ✗ | response.parameter_value().data()); | |
| 55 | } | ||
| 56 | |||
| 57 | } // namespace | ||
| 58 | |||
| 59 | namespace { | ||
| 60 | |||
| 61 | ✗ | int BuildUpdateBlocksFromFlat( | |
| 62 | const base::ConstArray<uint64_t>& keys, | ||
| 63 | const float* grads, | ||
| 64 | int64_t num_rows, | ||
| 65 | int64_t embedding_dim, | ||
| 66 | ParameterCompressor* compressor) { | ||
| 67 | ✗ | if (grads == nullptr) { | |
| 68 | ✗ | LOG(ERROR) << "UpdateParameterFlat grads pointer is null"; | |
| 69 | ✗ | return -1; | |
| 70 | } | ||
| 71 | ✗ | if (num_rows < 0 || embedding_dim <= 0) { | |
| 72 | ✗ | LOG(ERROR) << "UpdateParameterFlat invalid shape: rows=" << num_rows | |
| 73 | ✗ | << " dim=" << embedding_dim; | |
| 74 | ✗ | return -1; | |
| 75 | } | ||
| 76 | ✗ | if (keys.Size() != static_cast<size_t>(num_rows)) { | |
| 77 | ✗ | LOG(ERROR) << "UpdateParameterFlat keys/grads size mismatch: " | |
| 78 | ✗ | << keys.Size() << " vs " << num_rows; | |
| 79 | ✗ | return -1; | |
| 80 | } | ||
| 81 | |||
| 82 | ✗ | for (int64_t i = 0; i < num_rows; ++i) { | |
| 83 | ✗ | ParameterPack pack; | |
| 84 | ✗ | pack.key = keys[static_cast<size_t>(i)]; | |
| 85 | ✗ | pack.dim = embedding_dim; | |
| 86 | ✗ | pack.emb_data = grads + i * embedding_dim; | |
| 87 | ✗ | compressor->AddItem(pack, nullptr); | |
| 88 | } | ||
| 89 | ✗ | return 0; | |
| 90 | } | ||
| 91 | |||
| 92 | } // namespace | ||
| 93 | |||
| 94 | DEFINE_int32(brpc_timeout_ms, 5000, "brpc request timeout in milliseconds"); | ||
| 95 | DEFINE_int32(brpc_max_retry, 3, "brpc max retry times"); | ||
| 96 | DEFINE_bool(parameter_client_random_init_brpc, false, ""); | ||
| 97 | |||
| 98 | // New constructor that takes JSON config | ||
| 99 | 24 | BRPCParameterClient::BRPCParameterClient(json config) | |
| 100 |
1/2✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
|
24 | : recstore::BasePSClient(config) { |
| 101 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | host_ = config.value("host", "localhost"); |
| 102 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | port_ = config.value("port", 15000); |
| 103 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | shard_ = config.value("shard", 0); |
| 104 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | timeout_ms_ = config.value("timeout_ms", FLAGS_brpc_timeout_ms); |
| 105 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | max_retry_ = config.value("max_retry", FLAGS_brpc_max_retry); |
| 106 | |||
| 107 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | Initialize(); |
| 108 | |||
| 109 | // Initialize bRPC channel | ||
| 110 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | channel_ = std::make_shared<brpc::Channel>(); |
| 111 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | brpc::ChannelOptions options; |
| 112 | 24 | options.timeout_ms = timeout_ms_; | |
| 113 | 24 | options.max_retry = max_retry_; | |
| 114 | |||
| 115 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | std::string server_addr = fmt::format("{}:{}", host_, port_); |
| 116 |
2/4✓ Branch 3 taken 24 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 24 times.
|
24 | if (channel_->Init(server_addr.c_str(), &options) != 0) { |
| 117 | ✗ | LOG(ERROR) << "Failed to initialize bRPC channel to " << server_addr; | |
| 118 | } else { | ||
| 119 |
3/6✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
|
48 | LOG(INFO) << "Initialized bRPC PS Client Shard " << shard_ << " at " |
| 120 |
3/6✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
|
24 | << server_addr; |
| 121 | } | ||
| 122 | 24 | } | |
| 123 | |||
| 124 | // Legacy constructor for backward compatibility | ||
| 125 | 6 | BRPCParameterClient::BRPCParameterClient( | |
| 126 | 6 | const std::string& host, int port, int shard) | |
| 127 | : recstore::BasePSClient( | ||
| 128 |
14/28✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 10 not taken.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 20 not taken.
✓ Branch 22 taken 18 times.
✓ Branch 23 taken 6 times.
✓ Branch 25 taken 12 times.
✓ Branch 26 taken 6 times.
✓ Branch 28 taken 12 times.
✓ Branch 29 taken 6 times.
✓ Branch 31 taken 12 times.
✓ Branch 32 taken 6 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
66 | json{{"host", host}, {"port", port}, {"shard", shard}}), |
| 129 | 6 | host_(host), | |
| 130 | 6 | port_(port), | |
| 131 | 6 | shard_(shard), | |
| 132 | 6 | timeout_ms_(FLAGS_brpc_timeout_ms), | |
| 133 |
2/4✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 6 times.
✗ Branch 6 not taken.
|
18 | max_retry_(FLAGS_brpc_max_retry) { |
| 134 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | Initialize(); |
| 135 | |||
| 136 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | channel_ = std::make_shared<brpc::Channel>(); |
| 137 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | brpc::ChannelOptions options; |
| 138 | 6 | options.timeout_ms = timeout_ms_; | |
| 139 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | options.max_retry = max_retry_; |
| 140 | |||
| 141 | ✗ | std::string server_addr = fmt::format("{}:{}", host, port); | |
| 142 |
2/4✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
|
6 | if (channel_->Init(server_addr.c_str(), &options) != 0) { |
| 143 | ✗ | LOG(ERROR) << "Failed to initialize bRPC channel to " << server_addr; | |
| 144 | } else { | ||
| 145 |
3/6✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 8 not taken.
|
12 | LOG(INFO) << "Initialized bRPC PS Client Shard " << shard_ << " at " |
| 146 |
3/6✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 8 not taken.
|
6 | << server_addr; |
| 147 | } | ||
| 148 | 6 | } | |
| 149 | |||
| 150 | 30 | bool BRPCParameterClient::Initialize() { return true; } | |
| 151 | |||
| 152 | ✗ | int BRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys, | |
| 153 | float* values) { | ||
| 154 | #ifdef ENABLE_PERF_REPORT | ||
| 155 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 156 | #endif | ||
| 157 | |||
| 158 | ✗ | if (FLAGS_parameter_client_random_init_brpc) { | |
| 159 | ✗ | CHECK(0) << "todo implement"; | |
| 160 | return true; | ||
| 161 | } | ||
| 162 | |||
| 163 | int request_num = | ||
| 164 | ✗ | (keys.Size() + MAX_PARAMETER_BATCH_BRPC - 1) / MAX_PARAMETER_BATCH_BRPC; | |
| 165 | ✗ | std::vector<GetParameterRequest> requests(request_num); | |
| 166 | ✗ | std::vector<GetParameterResponse> responses(request_num); | |
| 167 | ✗ | std::vector<brpc::Controller> controllers(request_num); | |
| 168 | ✗ | std::vector<int> key_sizes; | |
| 169 | |||
| 170 | // Create stub | ||
| 171 | ✗ | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); | |
| 172 | |||
| 173 | #ifdef ENABLE_PERF_REPORT | ||
| 174 | auto wait_start_time = std::chrono::high_resolution_clock::now(); | ||
| 175 | #endif | ||
| 176 | |||
| 177 | // Send async RPC requests | ||
| 178 | ✗ | for (int start = 0, index = 0; start < keys.Size(); | |
| 179 | ✗ | start += MAX_PARAMETER_BATCH_BRPC, ++index) { | |
| 180 | int key_size = | ||
| 181 | ✗ | std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH_BRPC); | |
| 182 | ✗ | key_sizes.push_back(key_size); | |
| 183 | |||
| 184 | ✗ | controllers[index].request_attachment().append( | |
| 185 | ✗ | reinterpret_cast<const char*>(&keys[start]), | |
| 186 | ✗ | sizeof(uint64_t) * key_size); | |
| 187 | |||
| 188 | ✗ | google::protobuf::Closure* done = brpc::NewCallback([]() { /* no-op */ }); | |
| 189 | ✗ | stub.GetParameter( | |
| 190 | ✗ | &controllers[index], &requests[index], &responses[index], done); | |
| 191 | } | ||
| 192 | |||
| 193 | // Wait for all RPCs to complete | ||
| 194 | ✗ | for (int i = 0; i < request_num; ++i) { | |
| 195 | ✗ | brpc::Join(controllers[i].call_id()); | |
| 196 | ✗ | if (controllers[i].Failed()) { | |
| 197 | ✗ | LOG(ERROR) << "bRPC GetParameter failed: " << controllers[i].ErrorText(); | |
| 198 | ✗ | return false; | |
| 199 | } | ||
| 200 | } | ||
| 201 | |||
| 202 | #ifdef ENABLE_PERF_REPORT | ||
| 203 | auto wait_end_time = std::chrono::high_resolution_clock::now(); | ||
| 204 | auto wait_duration = | ||
| 205 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 206 | wait_end_time - wait_start_time) | ||
| 207 | .count(); | ||
| 208 | double wait_start_us = | ||
| 209 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 210 | wait_start_time.time_since_epoch()) | ||
| 211 | .count(); | ||
| 212 | std::string wait_label = | ||
| 213 | "brpc_client::RPC_Call_And_Wait_Shard" + std::to_string(shard_); | ||
| 214 | FlameGraphData wait_fg = { | ||
| 215 | wait_label, | ||
| 216 | wait_start_us, | ||
| 217 | 2, // level | ||
| 218 | static_cast<double>(wait_duration), | ||
| 219 | static_cast<double>(wait_duration)}; | ||
| 220 | std::string unique_id = | ||
| 221 | "embread_debug|" + std::to_string(static_cast<uint64_t>(wait_start_us)); | ||
| 222 | report_flame_graph("emb_read_flame_map", unique_id.c_str(), wait_fg); | ||
| 223 | |||
| 224 | double start_us_for_rpc = | ||
| 225 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 226 | start_time.time_since_epoch()) | ||
| 227 | .count(); | ||
| 228 | std::string report_id_for_rpc = | ||
| 229 | "brpc_client::GetParameter|" + | ||
| 230 | std::to_string(static_cast<uint64_t>(start_us_for_rpc)); | ||
| 231 | report("embread_stages", | ||
| 232 | report_id_for_rpc.c_str(), | ||
| 233 | "rpc_duration_us", | ||
| 234 | static_cast<double>(wait_duration)); | ||
| 235 | |||
| 236 | auto deserialize_start_time = std::chrono::high_resolution_clock::now(); | ||
| 237 | #endif | ||
| 238 | |||
| 239 | // Parse responses | ||
| 240 | ✗ | size_t get_embedding_acc = 0; | |
| 241 | ✗ | int old_dimension = -1; | |
| 242 | ✗ | std::string payload_storage; | |
| 243 | |||
| 244 | ✗ | for (int i = 0; i < responses.size(); ++i) { | |
| 245 | ✗ | auto& response = responses[i]; | |
| 246 | ✗ | int key_size = key_sizes[i]; | |
| 247 | ✗ | int payload_size = 0; | |
| 248 | ✗ | auto parameters = ExtractGetResponseReader( | |
| 249 | ✗ | controllers[i], response, &payload_storage, &payload_size); | |
| 250 | |||
| 251 | ✗ | if (parameters == nullptr || !parameters->Valid(payload_size)) { | |
| 252 | ✗ | LOG(ERROR) << "GetParameter invalid payload: " << payload_size; | |
| 253 | ✗ | return false; | |
| 254 | } | ||
| 255 | |||
| 256 | ✗ | if (parameters->size != key_size) { | |
| 257 | ✗ | LOG(ERROR) << "GetParameter error: " << parameters->size << " vs " | |
| 258 | ✗ | << key_size; | |
| 259 | ✗ | return false; | |
| 260 | } | ||
| 261 | |||
| 262 | ✗ | for (int index = 0; index < parameters->item_size(); ++index) { | |
| 263 | ✗ | auto item = parameters->item(index); | |
| 264 | ✗ | if (item->dim != 0) { | |
| 265 | ✗ | if (old_dimension == -1) | |
| 266 | ✗ | old_dimension = item->dim; | |
| 267 | ✗ | CHECK_EQ(item->dim, old_dimension); | |
| 268 | ✗ | std::copy_n( | |
| 269 | ✗ | item->embedding, item->dim, values + item->dim * get_embedding_acc); | |
| 270 | } else { | ||
| 271 | ✗ | RECSTORE_LOG_EVERY_MS(ERROR, 2000) | |
| 272 | ✗ | << "error; not find key " << keys[get_embedding_acc] << " in ps"; | |
| 273 | } | ||
| 274 | ✗ | get_embedding_acc++; | |
| 275 | } | ||
| 276 | } | ||
| 277 | |||
| 278 | #ifdef ENABLE_PERF_REPORT | ||
| 279 | auto deserialize_end_time = std::chrono::high_resolution_clock::now(); | ||
| 280 | auto deserialize_duration = | ||
| 281 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 282 | deserialize_end_time - deserialize_start_time) | ||
| 283 | .count(); | ||
| 284 | double deserialize_start_us = | ||
| 285 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 286 | deserialize_start_time.time_since_epoch()) | ||
| 287 | .count(); | ||
| 288 | std::string des_label = | ||
| 289 | "brpc_client::Deserialize_Shard" + std::to_string(shard_); | ||
| 290 | FlameGraphData des_fg = { | ||
| 291 | des_label, | ||
| 292 | deserialize_start_us, | ||
| 293 | 2, // level | ||
| 294 | static_cast<double>(deserialize_duration), | ||
| 295 | static_cast<double>(deserialize_duration)}; | ||
| 296 | std::string des_unique_id = | ||
| 297 | "embread_debug|" + | ||
| 298 | std::to_string(static_cast<uint64_t>(deserialize_start_us)); | ||
| 299 | report_flame_graph("emb_read_flame_map", des_unique_id.c_str(), des_fg); | ||
| 300 | |||
| 301 | double start_us_for_des = | ||
| 302 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 303 | start_time.time_since_epoch()) | ||
| 304 | .count(); | ||
| 305 | std::string report_id_for_des = | ||
| 306 | "brpc_client::GetParameter|" + | ||
| 307 | std::to_string(static_cast<uint64_t>(start_us_for_des)); | ||
| 308 | report("embread_stages", | ||
| 309 | report_id_for_des.c_str(), | ||
| 310 | "deserialize_duration_us", | ||
| 311 | static_cast<double>(deserialize_duration)); | ||
| 312 | #endif | ||
| 313 | |||
| 314 | #ifdef ENABLE_PERF_REPORT | ||
| 315 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 316 | auto duration = | ||
| 317 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 318 | end_time - start_time) | ||
| 319 | .count(); | ||
| 320 | report("ps_client_latency", | ||
| 321 | "GetParameter", | ||
| 322 | "latency_us", | ||
| 323 | static_cast<double>(duration)); | ||
| 324 | |||
| 325 | double start_us = | ||
| 326 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 327 | start_time.time_since_epoch()) | ||
| 328 | .count(); | ||
| 329 | FlameGraphData fg_data = { | ||
| 330 | "brpc_client::GetParameter", | ||
| 331 | start_us, | ||
| 332 | 1, // level | ||
| 333 | static_cast<double>(duration), | ||
| 334 | static_cast<double>(duration)}; | ||
| 335 | |||
| 336 | std::string report_id = "brpc_client::GetParameter|" + | ||
| 337 | std::to_string(static_cast<uint64_t>(start_us)); | ||
| 338 | |||
| 339 | report("embread_stages", | ||
| 340 | report_id.c_str(), | ||
| 341 | "duration_us", | ||
| 342 | static_cast<double>(duration)); | ||
| 343 | |||
| 344 | report("embread_stages", | ||
| 345 | report_id.c_str(), | ||
| 346 | "request_size", | ||
| 347 | static_cast<double>(keys.Size())); | ||
| 348 | |||
| 349 | std::string final_unique_id = | ||
| 350 | "embread_debug|" + std::to_string(static_cast<uint64_t>(start_us)); | ||
| 351 | report_flame_graph("emb_read_flame_map", final_unique_id.c_str(), fg_data); | ||
| 352 | #endif | ||
| 353 | |||
| 354 | ✗ | return true; | |
| 355 | ✗ | } | |
| 356 | |||
| 357 | 40 | int BRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys, | |
| 358 | std::vector<std::vector<float>>* values) { | ||
| 359 | #ifdef ENABLE_PERF_REPORT | ||
| 360 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 361 | #endif | ||
| 362 | |||
| 363 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 40 times.
|
40 | if (FLAGS_parameter_client_random_init_brpc) { |
| 364 | ✗ | values->clear(); | |
| 365 | ✗ | values->reserve(keys.Size()); | |
| 366 | ✗ | for (size_t i = 0; i < keys.Size(); i++) | |
| 367 | ✗ | values->emplace_back(std::vector<float>(128, 0.1)); | |
| 368 | ✗ | return true; | |
| 369 | } | ||
| 370 | |||
| 371 | 40 | values->clear(); | |
| 372 |
1/2✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
|
40 | values->reserve(keys.Size()); |
| 373 | |||
| 374 | int request_num = | ||
| 375 | 40 | (keys.Size() + MAX_PARAMETER_BATCH_BRPC - 1) / MAX_PARAMETER_BATCH_BRPC; | |
| 376 |
1/2✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
|
40 | std::vector<GetParameterRequest> requests(request_num); |
| 377 |
1/2✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
|
40 | std::vector<GetParameterResponse> responses(request_num); |
| 378 |
1/2✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
|
40 | std::vector<brpc::Controller> controllers(request_num); |
| 379 | 40 | std::vector<int> key_sizes; | |
| 380 | |||
| 381 |
1/2✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
|
40 | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); |
| 382 | |||
| 383 | #ifdef ENABLE_PERF_REPORT | ||
| 384 | auto wait_start_time = std::chrono::high_resolution_clock::now(); | ||
| 385 | #endif | ||
| 386 | |||
| 387 | // Send async RPC requests | ||
| 388 |
2/2✓ Branch 1 taken 40 times.
✓ Branch 2 taken 40 times.
|
80 | for (int start = 0, index = 0; start < keys.Size(); |
| 389 | 40 | start += MAX_PARAMETER_BATCH_BRPC, ++index) { | |
| 390 | int key_size = | ||
| 391 | 40 | std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH_BRPC); | |
| 392 |
1/2✓ Branch 1 taken 40 times.
✗ Branch 2 not taken.
|
40 | key_sizes.push_back(key_size); |
| 393 | |||
| 394 |
1/2✓ Branch 3 taken 40 times.
✗ Branch 4 not taken.
|
80 | controllers[index].request_attachment().append( |
| 395 | 40 | reinterpret_cast<const char*>(&keys[start]), | |
| 396 | 40 | sizeof(uint64_t) * key_size); | |
| 397 | |||
| 398 |
1/2✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
|
40 | google::protobuf::Closure* done = brpc::NewCallback([]() { /* no-op */ }); |
| 399 |
1/2✓ Branch 1 taken 40 times.
✗ Branch 2 not taken.
|
40 | stub.GetParameter( |
| 400 | 40 | &controllers[index], &requests[index], &responses[index], done); | |
| 401 | } | ||
| 402 | |||
| 403 | // Wait for all RPCs to complete | ||
| 404 |
2/2✓ Branch 0 taken 40 times.
✓ Branch 1 taken 40 times.
|
80 | for (int i = 0; i < request_num; ++i) { |
| 405 |
2/4✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 40 times.
✗ Branch 6 not taken.
|
40 | brpc::Join(controllers[i].call_id()); |
| 406 |
2/4✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 40 times.
|
40 | if (controllers[i].Failed()) { |
| 407 | ✗ | LOG(ERROR) << "bRPC GetParameter failed: " << controllers[i].ErrorText(); | |
| 408 | ✗ | return false; | |
| 409 | } | ||
| 410 | } | ||
| 411 | |||
| 412 | #ifdef ENABLE_PERF_REPORT | ||
| 413 | auto wait_end_time = std::chrono::high_resolution_clock::now(); | ||
| 414 | auto wait_duration = | ||
| 415 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 416 | wait_end_time - wait_start_time) | ||
| 417 | .count(); | ||
| 418 | double wait_start_us = | ||
| 419 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 420 | wait_start_time.time_since_epoch()) | ||
| 421 | .count(); | ||
| 422 | std::string wait_label = | ||
| 423 | "brpc_client::RPC_Call_And_Wait_Shard" + std::to_string(shard_); | ||
| 424 | FlameGraphData wait_fg = { | ||
| 425 | wait_label, | ||
| 426 | wait_start_us, | ||
| 427 | 2, // level | ||
| 428 | static_cast<double>(wait_duration), | ||
| 429 | static_cast<double>(wait_duration)}; | ||
| 430 | std::string unique_id = | ||
| 431 | "embread_debug|" + std::to_string(static_cast<uint64_t>(wait_start_us)); | ||
| 432 | report_flame_graph("emb_read_flame_map", unique_id.c_str(), wait_fg); | ||
| 433 | |||
| 434 | double start_us_for_rpc = | ||
| 435 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 436 | start_time.time_since_epoch()) | ||
| 437 | .count(); | ||
| 438 | std::string report_id_for_rpc = | ||
| 439 | "brpc_client::GetParameter_Vec|" + | ||
| 440 | std::to_string(static_cast<uint64_t>(start_us_for_rpc)); | ||
| 441 | |||
| 442 | report("embread_stages", | ||
| 443 | report_id_for_rpc.c_str(), | ||
| 444 | "rpc_duration_us", | ||
| 445 | static_cast<double>(wait_duration)); | ||
| 446 | |||
| 447 | auto deserialize_start_time = std::chrono::high_resolution_clock::now(); | ||
| 448 | #endif | ||
| 449 | |||
| 450 | // Parse responses | ||
| 451 | 40 | std::string payload_storage; | |
| 452 |
2/2✓ Branch 1 taken 40 times.
✓ Branch 2 taken 40 times.
|
80 | for (int i = 0; i < responses.size(); ++i) { |
| 453 | 40 | auto& response = responses[i]; | |
| 454 | 40 | int key_size = key_sizes[i]; | |
| 455 | 40 | int payload_size = 0; | |
| 456 |
1/2✓ Branch 1 taken 40 times.
✗ Branch 2 not taken.
|
40 | auto parameters = ExtractGetResponseReader( |
| 457 | 40 | controllers[i], response, &payload_storage, &payload_size); | |
| 458 | |||
| 459 |
4/8✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
✓ Branch 3 taken 40 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 40 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 40 times.
|
40 | if (parameters == nullptr || !parameters->Valid(payload_size)) { |
| 460 | ✗ | LOG(ERROR) << "GetParameter(vector) invalid payload: " << payload_size; | |
| 461 | ✗ | return false; | |
| 462 | } | ||
| 463 | |||
| 464 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 40 times.
|
40 | if (unlikely(parameters->size != key_size)) { |
| 465 | ✗ | LOG(ERROR) << "GetParameter error: " << parameters->size << " vs " | |
| 466 | ✗ | << key_size; | |
| 467 | ✗ | return false; | |
| 468 | } | ||
| 469 | |||
| 470 |
2/2✓ Branch 1 taken 266 times.
✓ Branch 2 taken 40 times.
|
306 | for (int index = 0; index < parameters->item_size(); ++index) { |
| 471 |
1/2✓ Branch 1 taken 266 times.
✗ Branch 2 not taken.
|
266 | auto item = parameters->item(index); |
| 472 |
2/2✓ Branch 0 taken 254 times.
✓ Branch 1 taken 12 times.
|
266 | if (item->dim != 0) { |
| 473 |
1/2✓ Branch 1 taken 254 times.
✗ Branch 2 not taken.
|
254 | values->emplace_back( |
| 474 |
1/2✓ Branch 2 taken 254 times.
✗ Branch 3 not taken.
|
508 | std::vector<float>(item->embedding, item->embedding + item->dim)); |
| 475 | } else { | ||
| 476 |
2/4✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 12 times.
✗ Branch 6 not taken.
|
12 | values->emplace_back(std::vector<float>(0)); |
| 477 | } | ||
| 478 | } | ||
| 479 | } | ||
| 480 | |||
| 481 | #ifdef ENABLE_PERF_REPORT | ||
| 482 | auto deserialize_end_time = std::chrono::high_resolution_clock::now(); | ||
| 483 | auto deserialize_duration = | ||
| 484 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 485 | deserialize_end_time - deserialize_start_time) | ||
| 486 | .count(); | ||
| 487 | double deserialize_start_us = | ||
| 488 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 489 | deserialize_start_time.time_since_epoch()) | ||
| 490 | .count(); | ||
| 491 | std::string des_label = | ||
| 492 | "brpc_client::Deserialize_Shard" + std::to_string(shard_); | ||
| 493 | FlameGraphData des_fg = { | ||
| 494 | des_label, | ||
| 495 | deserialize_start_us, | ||
| 496 | 2, // level | ||
| 497 | static_cast<double>(deserialize_duration), | ||
| 498 | static_cast<double>(deserialize_duration)}; | ||
| 499 | std::string des_unique_id = | ||
| 500 | "embread_debug|" + | ||
| 501 | std::to_string(static_cast<uint64_t>(deserialize_start_us)); | ||
| 502 | report_flame_graph("emb_read_flame_map", des_unique_id.c_str(), des_fg); | ||
| 503 | |||
| 504 | double start_us_for_des = | ||
| 505 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 506 | start_time.time_since_epoch()) | ||
| 507 | .count(); | ||
| 508 | std::string report_id_for_des = | ||
| 509 | "brpc_client::GetParameter_Vec|" + | ||
| 510 | std::to_string(static_cast<uint64_t>(start_us_for_des)); | ||
| 511 | report("embread_stages", | ||
| 512 | report_id_for_des.c_str(), | ||
| 513 | "deserialize_duration_us", | ||
| 514 | static_cast<double>(deserialize_duration)); | ||
| 515 | #endif | ||
| 516 | |||
| 517 | #ifdef ENABLE_PERF_REPORT | ||
| 518 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 519 | auto duration = | ||
| 520 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 521 | end_time - start_time) | ||
| 522 | .count(); | ||
| 523 | report("ps_client_latency", | ||
| 524 | "GetParameter", | ||
| 525 | "latency_us", | ||
| 526 | static_cast<double>(duration)); | ||
| 527 | |||
| 528 | double start_us = | ||
| 529 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 530 | start_time.time_since_epoch()) | ||
| 531 | .count(); | ||
| 532 | FlameGraphData fg_data = { | ||
| 533 | "brpc_client::GetParameter_Vec", | ||
| 534 | start_us, | ||
| 535 | 1, // level | ||
| 536 | static_cast<double>(duration), | ||
| 537 | static_cast<double>(duration)}; | ||
| 538 | |||
| 539 | std::string report_id = "brpc_client::GetParameter_Vec|" + | ||
| 540 | std::to_string(static_cast<uint64_t>(start_us)); | ||
| 541 | |||
| 542 | report("embread_stages", | ||
| 543 | report_id.c_str(), | ||
| 544 | "duration_us", | ||
| 545 | static_cast<double>(duration)); | ||
| 546 | |||
| 547 | report("embread_stages", | ||
| 548 | report_id.c_str(), | ||
| 549 | "request_size", | ||
| 550 | static_cast<double>(keys.Size())); | ||
| 551 | |||
| 552 | std::string final_unique_id = | ||
| 553 | "embread_debug|" + std::to_string(static_cast<uint64_t>(start_us)); | ||
| 554 | report_flame_graph("emb_read_flame_map", final_unique_id.c_str(), fg_data); | ||
| 555 | #endif | ||
| 556 | |||
| 557 | 40 | return true; | |
| 558 | 40 | } | |
| 559 | |||
| 560 | 10 | static void OnPrefetchDone(BrpcPrefetchBatch* batch) { | |
| 561 | 10 | batch->completed_count_++; | |
| 562 | 10 | } | |
| 563 | |||
| 564 | 8 | static void OnPrewriteDone(BrpcPrewriteBatch* batch) { | |
| 565 | 8 | batch->completed_count_++; | |
| 566 | 8 | } | |
| 567 | |||
| 568 | uint64_t | ||
| 569 | 10 | BRPCParameterClient::PrefetchParameter(const base::ConstArray<uint64_t>& keys) { | |
| 570 | 10 | uint64_t prefetch_id = next_prefetch_id_++; | |
| 571 | int request_num = | ||
| 572 | 10 | (keys.Size() + MAX_PARAMETER_BATCH_BRPC - 1) / MAX_PARAMETER_BATCH_BRPC; | |
| 573 | |||
| 574 | // Construct in map so batch pointers stay valid | ||
| 575 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | auto it = prefetch_batches_.emplace(prefetch_id, request_num).first; |
| 576 | 10 | struct BrpcPrefetchBatch* pb = &it->second; | |
| 577 | |||
| 578 |
1/2✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
|
10 | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); |
| 579 | |||
| 580 |
2/2✓ Branch 1 taken 10 times.
✓ Branch 2 taken 10 times.
|
20 | for (int start = 0, index = 0; start < keys.Size(); |
| 581 | 10 | start += MAX_PARAMETER_BATCH_BRPC, ++index) { | |
| 582 | int key_size = | ||
| 583 | 10 | std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH_BRPC); | |
| 584 | 10 | pb->key_sizes_[index] = key_size; | |
| 585 | |||
| 586 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | GetParameterRequest request; |
| 587 | |||
| 588 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | pb->controllers_[index] = std::make_unique<brpc::Controller>(); |
| 589 |
1/2✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
|
20 | pb->controllers_[index]->request_attachment().append( |
| 590 | 10 | reinterpret_cast<const char*>(&keys[start]), | |
| 591 | 10 | sizeof(uint64_t) * key_size); | |
| 592 | |||
| 593 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | google::protobuf::Closure* done = brpc::NewCallback(OnPrefetchDone, pb); |
| 594 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | stub.GetParameter( |
| 595 | 10 | pb->controllers_[index].get(), &request, &pb->responses_[index], done); | |
| 596 | 10 | } | |
| 597 | |||
| 598 | 10 | return prefetch_id; | |
| 599 | 10 | } | |
| 600 | |||
| 601 | ✗ | bool BRPCParameterClient::IsPrefetchDone(uint64_t prefetch_id) { | |
| 602 | ✗ | auto it = prefetch_batches_.find(prefetch_id); | |
| 603 | ✗ | if (it == prefetch_batches_.end()) { | |
| 604 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 605 | ✗ | return false; | |
| 606 | } | ||
| 607 | |||
| 608 | ✗ | auto& pb = it->second; | |
| 609 | |||
| 610 | ✗ | return pb.completed_count_ == pb.batch_size_; | |
| 611 | } | ||
| 612 | |||
| 613 | 10 | void BRPCParameterClient::WaitForPrefetch(uint64_t prefetch_id) { | |
| 614 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | auto it = prefetch_batches_.find(prefetch_id); |
| 615 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
|
10 | if (it == prefetch_batches_.end()) { |
| 616 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 617 | ✗ | return; | |
| 618 | } | ||
| 619 | 10 | auto& pb = it->second; | |
| 620 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | for (int i = 0; i < pb.batch_size_; ++i) { |
| 621 |
1/2✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
|
10 | if (pb.controllers_[i]) { |
| 622 |
2/4✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 10 times.
✗ Branch 7 not taken.
|
10 | brpc::Join(pb.controllers_[i]->call_id()); |
| 623 | } | ||
| 624 | } | ||
| 625 | 10 | pb.completed_count_ = pb.batch_size_; | |
| 626 | } | ||
| 627 | |||
| 628 | 10 | bool BRPCParameterClient::GetPrefetchResult( | |
| 629 | uint64_t prefetch_id, std::vector<std::vector<float>>* values) { | ||
| 630 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | auto it = prefetch_batches_.find(prefetch_id); |
| 631 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
|
10 | if (it == prefetch_batches_.end()) { |
| 632 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 633 | ✗ | return false; | |
| 634 | } | ||
| 635 | |||
| 636 | 10 | auto& pb = it->second; | |
| 637 | 10 | int request_num = pb.batch_size_; | |
| 638 | |||
| 639 | 10 | values->clear(); | |
| 640 | 10 | int keys_size = 0; | |
| 641 |
2/2✓ Branch 4 taken 10 times.
✓ Branch 5 taken 10 times.
|
20 | for (const auto& size : pb.key_sizes_) { |
| 642 | 10 | keys_size += size; | |
| 643 | } | ||
| 644 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | values->reserve(keys_size); |
| 645 | |||
| 646 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | for (int i = 0; i < request_num; ++i) { |
| 647 |
2/4✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
|
10 | if (pb.controllers_[i]->Failed()) { |
| 648 | ✗ | LOG(ERROR) << "Prefetch request failed: " | |
| 649 | ✗ | << pb.controllers_[i]->ErrorText(); | |
| 650 | ✗ | return false; | |
| 651 | } | ||
| 652 | |||
| 653 | 10 | auto& response = pb.responses_[i]; | |
| 654 | 10 | int key_size = pb.key_sizes_[i]; | |
| 655 | 10 | std::string payload_storage; | |
| 656 | 10 | int payload_size = 0; | |
| 657 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | auto parameters = ExtractGetResponseReader( |
| 658 | 10 | *pb.controllers_[i], response, &payload_storage, &payload_size); | |
| 659 | |||
| 660 |
4/8✓ Branch 0 taken 10 times.
✗ Branch 1 not taken.
✓ Branch 3 taken 10 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 10 times.
|
10 | if (parameters == nullptr || !parameters->Valid(payload_size)) { |
| 661 | ✗ | LOG(ERROR) << "Prefetch invalid payload: " << payload_size; | |
| 662 | ✗ | return false; | |
| 663 | } | ||
| 664 | |||
| 665 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
|
10 | if (unlikely(parameters->size != key_size)) { |
| 666 | ✗ | LOG(ERROR) << "GetParameter error: " << parameters->size << " vs " | |
| 667 | ✗ | << key_size; | |
| 668 | ✗ | return false; | |
| 669 | } | ||
| 670 | |||
| 671 |
2/2✓ Branch 1 taken 106 times.
✓ Branch 2 taken 10 times.
|
116 | for (int index = 0; index < parameters->item_size(); ++index) { |
| 672 |
1/2✓ Branch 1 taken 106 times.
✗ Branch 2 not taken.
|
106 | auto item = parameters->item(index); |
| 673 |
1/2✓ Branch 0 taken 106 times.
✗ Branch 1 not taken.
|
106 | if (item->dim != 0) { |
| 674 |
1/2✓ Branch 1 taken 106 times.
✗ Branch 2 not taken.
|
106 | values->emplace_back( |
| 675 |
1/2✓ Branch 2 taken 106 times.
✗ Branch 3 not taken.
|
212 | std::vector<float>(item->embedding, item->embedding + item->dim)); |
| 676 | } else { | ||
| 677 | ✗ | values->emplace_back(std::vector<float>(0)); | |
| 678 | } | ||
| 679 | } | ||
| 680 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | } |
| 681 | |||
| 682 | // Remove completed batch | ||
| 683 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | prefetch_batches_.erase(it); |
| 684 | |||
| 685 | 10 | return true; | |
| 686 | } | ||
| 687 | |||
| 688 | ✗ | bool BRPCParameterClient::GetPrefetchResultFlat( | |
| 689 | uint64_t prefetch_id, | ||
| 690 | std::vector<float>* values, | ||
| 691 | int64_t* num_rows, | ||
| 692 | int64_t embedding_dim) { | ||
| 693 | ✗ | auto it = prefetch_batches_.find(prefetch_id); | |
| 694 | ✗ | if (it == prefetch_batches_.end()) { | |
| 695 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 696 | ✗ | return false; | |
| 697 | } | ||
| 698 | ✗ | if (values == nullptr || num_rows == nullptr) { | |
| 699 | ✗ | LOG(ERROR) << "GetPrefetchResultFlat output pointer is null"; | |
| 700 | ✗ | return false; | |
| 701 | } | ||
| 702 | |||
| 703 | ✗ | auto& pb = it->second; | |
| 704 | ✗ | int request_num = pb.batch_size_; | |
| 705 | ✗ | int total_keys = 0; | |
| 706 | ✗ | for (const auto& size : pb.key_sizes_) { | |
| 707 | ✗ | total_keys += size; | |
| 708 | } | ||
| 709 | |||
| 710 | ✗ | *num_rows = static_cast<int64_t>(total_keys); | |
| 711 | ✗ | values->assign( | |
| 712 | ✗ | static_cast<size_t>(*num_rows) * static_cast<size_t>(embedding_dim), | |
| 713 | ✗ | 0.0f); | |
| 714 | |||
| 715 | ✗ | size_t row_offset = 0; | |
| 716 | ✗ | for (int i = 0; i < request_num; ++i) { | |
| 717 | ✗ | if (pb.controllers_[i]->Failed()) { | |
| 718 | ✗ | LOG(ERROR) << "Prefetch request failed: " | |
| 719 | ✗ | << pb.controllers_[i]->ErrorText(); | |
| 720 | ✗ | return false; | |
| 721 | } | ||
| 722 | |||
| 723 | ✗ | auto& response = pb.responses_[i]; | |
| 724 | ✗ | int key_size = pb.key_sizes_[i]; | |
| 725 | ✗ | std::string payload_storage; | |
| 726 | ✗ | int payload_size = 0; | |
| 727 | ✗ | auto parameters = ExtractGetResponseReader( | |
| 728 | ✗ | *pb.controllers_[i], response, &payload_storage, &payload_size); | |
| 729 | |||
| 730 | ✗ | if (parameters == nullptr || !parameters->Valid(payload_size)) { | |
| 731 | ✗ | LOG(ERROR) << "Prefetch invalid payload: " << payload_size; | |
| 732 | ✗ | return false; | |
| 733 | } | ||
| 734 | |||
| 735 | ✗ | if (unlikely(parameters->size != key_size)) { | |
| 736 | ✗ | LOG(ERROR) << "GetParameter error: " << parameters->size << " vs " | |
| 737 | ✗ | << key_size; | |
| 738 | ✗ | return false; | |
| 739 | } | ||
| 740 | |||
| 741 | ✗ | for (int index = 0; index < parameters->item_size(); | |
| 742 | ✗ | ++index, ++row_offset) { | |
| 743 | ✗ | auto item = parameters->item(index); | |
| 744 | ✗ | if (item->dim != 0) { | |
| 745 | const int64_t copy_d = | ||
| 746 | ✗ | std::min<int64_t>(embedding_dim, static_cast<int64_t>(item->dim)); | |
| 747 | ✗ | std::memcpy(values->data() + row_offset * embedding_dim, | |
| 748 | ✗ | item->embedding, | |
| 749 | ✗ | static_cast<size_t>(copy_d) * sizeof(float)); | |
| 750 | } | ||
| 751 | } | ||
| 752 | ✗ | } | |
| 753 | |||
| 754 | ✗ | prefetch_batches_.erase(it); | |
| 755 | ✗ | return true; | |
| 756 | } | ||
| 757 | |||
| 758 | 34 | bool BRPCParameterClient::ClearPS() { | |
| 759 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | CommandRequest request; |
| 760 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | CommandResponse response; |
| 761 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | request.set_command(PSCommand::CLEAR_PS); |
| 762 | |||
| 763 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | brpc::Controller cntl; |
| 764 |
1/2✓ Branch 2 taken 34 times.
✗ Branch 3 not taken.
|
34 | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); |
| 765 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | stub.Command(&cntl, &request, &response, nullptr); |
| 766 | |||
| 767 |
2/4✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 34 times.
|
34 | if (cntl.Failed()) { |
| 768 | ✗ | LOG(ERROR) << "bRPC Command failed: " << cntl.ErrorText(); | |
| 769 | ✗ | return false; | |
| 770 | } | ||
| 771 | 34 | return true; | |
| 772 | 34 | } | |
| 773 | |||
| 774 | 6 | bool BRPCParameterClient::LoadFakeData(int64_t data) { | |
| 775 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | CommandRequest request; |
| 776 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | CommandResponse response; |
| 777 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | request.set_command(PSCommand::LOAD_FAKE_DATA); |
| 778 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | request.add_arg1(&data, sizeof(int64_t)); |
| 779 | |||
| 780 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | brpc::Controller cntl; |
| 781 |
1/2✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); |
| 782 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | stub.Command(&cntl, &request, &response, nullptr); |
| 783 | |||
| 784 |
2/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
|
6 | if (cntl.Failed()) { |
| 785 | ✗ | LOG(ERROR) << "bRPC LoadFakeData failed: " << cntl.ErrorText(); | |
| 786 | ✗ | return false; | |
| 787 | } | ||
| 788 |
2/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
|
6 | if (response.reply().size() != static_cast<size_t>(data)) { |
| 789 | ✗ | LOG(ERROR) << "bRPC LoadFakeData reply size mismatch: expected " << data | |
| 790 | ✗ | << ", got " << response.reply().size(); | |
| 791 | ✗ | return false; | |
| 792 | } | ||
| 793 | 6 | return true; | |
| 794 | 6 | } | |
| 795 | |||
| 796 | 6 | bool BRPCParameterClient::DumpFakeData(int64_t n) { | |
| 797 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | CommandRequest request; |
| 798 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | CommandResponse response; |
| 799 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | request.set_command(PSCommand::DUMP_FAKE_DATA); |
| 800 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | request.add_arg1(&n, sizeof(int64_t)); |
| 801 | |||
| 802 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | brpc::Controller cntl; |
| 803 |
1/2✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); |
| 804 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | stub.Command(&cntl, &request, &response, nullptr); |
| 805 | |||
| 806 |
2/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
|
6 | if (cntl.Failed()) { |
| 807 | ✗ | LOG(ERROR) << "bRPC DumpFakeData failed: " << cntl.ErrorText(); | |
| 808 | ✗ | return false; | |
| 809 | } | ||
| 810 |
2/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
|
6 | if (response.reply() != "ok") { |
| 811 | ✗ | LOG(ERROR) << "bRPC DumpFakeData unexpected reply: " << response.reply(); | |
| 812 | ✗ | return false; | |
| 813 | } | ||
| 814 | 6 | return true; | |
| 815 | 6 | } | |
| 816 | |||
| 817 | ✗ | bool BRPCParameterClient::LoadCkpt( | |
| 818 | const std::vector<std::string>& model_config_path, | ||
| 819 | const std::vector<std::string>& emb_file_path) { | ||
| 820 | ✗ | CommandRequest request; | |
| 821 | ✗ | CommandResponse response; | |
| 822 | ✗ | request.set_command(PSCommand::RELOAD_PS); | |
| 823 | |||
| 824 | ✗ | for (auto& each : model_config_path) { | |
| 825 | ✗ | request.add_arg1(each); | |
| 826 | } | ||
| 827 | ✗ | for (auto& each : emb_file_path) { | |
| 828 | ✗ | request.add_arg2(each); | |
| 829 | } | ||
| 830 | |||
| 831 | ✗ | brpc::Controller cntl; | |
| 832 | ✗ | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); | |
| 833 | ✗ | stub.Command(&cntl, &request, &response, nullptr); | |
| 834 | |||
| 835 | ✗ | if (cntl.Failed()) { | |
| 836 | ✗ | LOG(ERROR) << "bRPC LoadCkpt failed: " << cntl.ErrorText(); | |
| 837 | ✗ | return false; | |
| 838 | } | ||
| 839 | ✗ | return true; | |
| 840 | ✗ | } | |
| 841 | |||
| 842 | 20 | bool BRPCParameterClient::PutParameter( | |
| 843 | const std::vector<uint64_t>& keys, | ||
| 844 | const std::vector<std::vector<float>>& values) { | ||
| 845 | #ifdef ENABLE_PERF_REPORT | ||
| 846 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 847 | #endif | ||
| 848 | |||
| 849 |
1/2✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
|
20 | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); |
| 850 | |||
| 851 |
2/2✓ Branch 1 taken 20 times.
✓ Branch 2 taken 20 times.
|
40 | for (int start = 0, index = 0; start < keys.size(); |
| 852 | 20 | start += MAX_PARAMETER_BATCH_BRPC, ++index) { | |
| 853 | int key_size = | ||
| 854 | 20 | std::min((int)(keys.size() - start), MAX_PARAMETER_BATCH_BRPC); | |
| 855 | |||
| 856 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | PutParameterRequest request; |
| 857 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | PutParameterResponse response; |
| 858 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | ParameterCompressor compressor; |
| 859 | |||
| 860 |
2/2✓ Branch 0 taken 234 times.
✓ Branch 1 taken 20 times.
|
254 | for (int i = start; i < start + key_size; i++) { |
| 861 | 234 | auto each_key = keys[i]; | |
| 862 | 234 | auto& embedding = values[i]; | |
| 863 | 234 | ParameterPack parameter_pack; | |
| 864 | 234 | parameter_pack.key = each_key; | |
| 865 | 234 | parameter_pack.dim = embedding.size(); | |
| 866 | 234 | parameter_pack.emb_data = embedding.data(); | |
| 867 |
1/2✓ Branch 1 taken 234 times.
✗ Branch 2 not taken.
|
234 | compressor.AddItem(parameter_pack, nullptr); |
| 868 | } | ||
| 869 | |||
| 870 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | brpc::Controller cntl; |
| 871 |
1/2✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
|
20 | compressor.AppendToIOBuf(&cntl.request_attachment()); |
| 872 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | stub.PutParameter(&cntl, &request, &response, nullptr); |
| 873 | |||
| 874 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
|
20 | if (cntl.Failed()) { |
| 875 | ✗ | LOG(ERROR) << "bRPC PutParameter failed: " << cntl.ErrorText(); | |
| 876 | ✗ | return false; | |
| 877 | } | ||
| 878 |
4/8✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
|
20 | } |
| 879 | |||
| 880 | #ifdef ENABLE_PERF_REPORT | ||
| 881 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 882 | auto duration = | ||
| 883 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 884 | end_time - start_time) | ||
| 885 | .count(); | ||
| 886 | report("ps_client_latency", | ||
| 887 | "PutParameter", | ||
| 888 | "latency_us", | ||
| 889 | static_cast<double>(duration)); | ||
| 890 | #endif | ||
| 891 | |||
| 892 | 20 | return true; | |
| 893 | 20 | } | |
| 894 | |||
| 895 | ✗ | int BRPCParameterClient::AsyncGetParameter( | |
| 896 | const base::ConstArray<uint64_t>& keys, float* values) { | ||
| 897 | ✗ | return GetParameter(keys, values); | |
| 898 | } | ||
| 899 | |||
| 900 | 14 | int BRPCParameterClient::PutParameter( | |
| 901 | const base::ConstArray<uint64_t>& keys, | ||
| 902 | const std::vector<std::vector<float>>& values) { | ||
| 903 |
1/2✓ Branch 5 taken 14 times.
✗ Branch 6 not taken.
|
14 | std::vector<uint64_t> key_vec(keys.Data(), keys.Data() + keys.Size()); |
| 904 |
1/2✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
|
14 | bool success = PutParameter(key_vec, values); |
| 905 |
1/2✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
|
28 | return success ? 1 : 0; |
| 906 | 14 | } | |
| 907 | |||
| 908 | ✗ | void BRPCParameterClient::Command(recstore::PSCommand command) { | |
| 909 | ✗ | switch (command) { | |
| 910 | ✗ | case recstore::PSCommand::CLEAR_PS: | |
| 911 | ✗ | ClearPS(); | |
| 912 | ✗ | break; | |
| 913 | ✗ | case recstore::PSCommand::RELOAD_PS: | |
| 914 | ✗ | LOG(WARNING) << "RELOAD_PS command requires additional parameters"; | |
| 915 | ✗ | break; | |
| 916 | ✗ | case recstore::PSCommand::LOAD_FAKE_DATA: { | |
| 917 | ✗ | int64_t fake_data = 1000; | |
| 918 | ✗ | LoadFakeData(fake_data); | |
| 919 | ✗ | } break; | |
| 920 | ✗ | case recstore::PSCommand::DUMP_FAKE_DATA: { | |
| 921 | ✗ | DumpFakeData(4096); | |
| 922 | ✗ | } break; | |
| 923 | ✗ | default: | |
| 924 | ✗ | LOG(ERROR) << "Unknown PS command: " << static_cast<int>(command); | |
| 925 | ✗ | break; | |
| 926 | } | ||
| 927 | ✗ | } | |
| 928 | |||
| 929 | ✗ | int BRPCParameterClient::UpdateParameter( | |
| 930 | const std::string& table_name, | ||
| 931 | const base::ConstArray<uint64_t>& keys, | ||
| 932 | const std::vector<std::vector<float>>* grads) { | ||
| 933 | #ifdef ENABLE_PERF_REPORT | ||
| 934 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 935 | const uint64_t trace_id = recstore::g_trace_id; | ||
| 936 | #endif | ||
| 937 | ✗ | if (grads == nullptr) { | |
| 938 | ✗ | LOG(ERROR) << "UpdateParameter grads pointer is null"; | |
| 939 | ✗ | return -1; | |
| 940 | } | ||
| 941 | ✗ | if (keys.Size() != grads->size()) { | |
| 942 | ✗ | LOG(ERROR) << "UpdateParameter keys/grads size mismatch: " << keys.Size() | |
| 943 | ✗ | << " vs " << grads->size(); | |
| 944 | ✗ | return -1; | |
| 945 | } | ||
| 946 | |||
| 947 | ✗ | ParameterCompressor compressor; | |
| 948 | ✗ | for (size_t i = 0; i < keys.Size(); ++i) { | |
| 949 | ✗ | ParameterPack pack; | |
| 950 | ✗ | pack.key = keys[i]; | |
| 951 | ✗ | pack.dim = grads->at(i).size(); | |
| 952 | ✗ | pack.emb_data = grads->at(i).data(); | |
| 953 | ✗ | compressor.AddItem(pack, nullptr); | |
| 954 | } | ||
| 955 | #ifdef ENABLE_PERF_REPORT | ||
| 956 | auto serialize_done_time = std::chrono::high_resolution_clock::now(); | ||
| 957 | #endif | ||
| 958 | ✗ | if (keys.Size() == 0) { | |
| 959 | ✗ | LOG(WARNING) << "UpdateParameter no gradients to send"; | |
| 960 | ✗ | return 0; | |
| 961 | } | ||
| 962 | |||
| 963 | ✗ | UpdateParameterRequest request; | |
| 964 | ✗ | UpdateParameterResponse response; | |
| 965 | request.set_table_name(table_name); | ||
| 966 | |||
| 967 | ✗ | brpc::Controller cntl; | |
| 968 | #ifdef ENABLE_PERF_REPORT | ||
| 969 | if (trace_id != 0) { | ||
| 970 | cntl.http_request().SetHeader( | ||
| 971 | "x-recstore-trace-id", std::to_string(trace_id)); | ||
| 972 | } | ||
| 973 | auto rpc_start_time = std::chrono::high_resolution_clock::now(); | ||
| 974 | #endif | ||
| 975 | ✗ | compressor.AppendToIOBuf(&cntl.request_attachment()); | |
| 976 | ✗ | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); | |
| 977 | ✗ | stub.UpdateParameter(&cntl, &request, &response, nullptr); | |
| 978 | ✗ | if (cntl.Failed()) { | |
| 979 | ✗ | LOG(ERROR) << "UpdateParameter RPC failed: " << cntl.ErrorText(); | |
| 980 | ✗ | return -1; | |
| 981 | } | ||
| 982 | |||
| 983 | #ifdef ENABLE_PERF_REPORT | ||
| 984 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 985 | auto duration = | ||
| 986 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 987 | end_time - start_time) | ||
| 988 | .count(); | ||
| 989 | auto serialize_duration = | ||
| 990 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 991 | serialize_done_time - start_time) | ||
| 992 | .count(); | ||
| 993 | auto rpc_duration = | ||
| 994 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 995 | end_time - rpc_start_time) | ||
| 996 | .count(); | ||
| 997 | report("ps_client_latency", | ||
| 998 | "UpdateParameter", | ||
| 999 | "latency_us", | ||
| 1000 | static_cast<double>(duration)); | ||
| 1001 | |||
| 1002 | const uint64_t effective_trace_id = | ||
| 1003 | trace_id == 0 | ||
| 1004 | ? static_cast<uint64_t>( | ||
| 1005 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 1006 | start_time.time_since_epoch()) | ||
| 1007 | .count()) | ||
| 1008 | : trace_id; | ||
| 1009 | std::string stage_id = | ||
| 1010 | "brpc_client::EmbUpdate|" + std::to_string(effective_trace_id); | ||
| 1011 | report("embupdate_stages", | ||
| 1012 | stage_id.c_str(), | ||
| 1013 | "client_serialize_us", | ||
| 1014 | static_cast<double>(serialize_duration)); | ||
| 1015 | report("embupdate_stages", | ||
| 1016 | stage_id.c_str(), | ||
| 1017 | "client_rpc_us", | ||
| 1018 | static_cast<double>(rpc_duration)); | ||
| 1019 | report("embupdate_stages", | ||
| 1020 | stage_id.c_str(), | ||
| 1021 | "client_total_us", | ||
| 1022 | static_cast<double>(duration)); | ||
| 1023 | report("embupdate_stages", | ||
| 1024 | stage_id.c_str(), | ||
| 1025 | "client_request_size", | ||
| 1026 | static_cast<double>(keys.Size())); | ||
| 1027 | #endif | ||
| 1028 | |||
| 1029 | ✗ | return response.success() ? 0 : -1; | |
| 1030 | ✗ | } | |
| 1031 | |||
| 1032 | ✗ | int BRPCParameterClient::UpdateParameterFlat( | |
| 1033 | const std::string& table_name, | ||
| 1034 | const base::ConstArray<uint64_t>& keys, | ||
| 1035 | const float* grads, | ||
| 1036 | int64_t num_rows, | ||
| 1037 | int64_t embedding_dim) { | ||
| 1038 | #ifdef ENABLE_PERF_REPORT | ||
| 1039 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 1040 | const uint64_t trace_id = recstore::g_trace_id; | ||
| 1041 | #endif | ||
| 1042 | ✗ | if (keys.Size() == 0) { | |
| 1043 | ✗ | return 0; | |
| 1044 | } | ||
| 1045 | |||
| 1046 | ✗ | ParameterCompressor compressor; | |
| 1047 | ✗ | if (BuildUpdateBlocksFromFlat( | |
| 1048 | ✗ | keys, grads, num_rows, embedding_dim, &compressor) != 0) { | |
| 1049 | ✗ | return -1; | |
| 1050 | } | ||
| 1051 | #ifdef ENABLE_PERF_REPORT | ||
| 1052 | auto serialize_done_time = std::chrono::high_resolution_clock::now(); | ||
| 1053 | #endif | ||
| 1054 | |||
| 1055 | ✗ | UpdateParameterRequest request; | |
| 1056 | ✗ | UpdateParameterResponse response; | |
| 1057 | request.set_table_name(table_name); | ||
| 1058 | |||
| 1059 | ✗ | brpc::Controller cntl; | |
| 1060 | #ifdef ENABLE_PERF_REPORT | ||
| 1061 | if (trace_id != 0) { | ||
| 1062 | cntl.http_request().SetHeader( | ||
| 1063 | "x-recstore-trace-id", std::to_string(trace_id)); | ||
| 1064 | } | ||
| 1065 | auto rpc_start_time = std::chrono::high_resolution_clock::now(); | ||
| 1066 | #endif | ||
| 1067 | ✗ | compressor.AppendToIOBuf(&cntl.request_attachment()); | |
| 1068 | ✗ | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); | |
| 1069 | ✗ | stub.UpdateParameter(&cntl, &request, &response, nullptr); | |
| 1070 | ✗ | if (cntl.Failed()) { | |
| 1071 | ✗ | LOG(ERROR) << "UpdateParameterFlat RPC failed: " << cntl.ErrorText(); | |
| 1072 | ✗ | return -1; | |
| 1073 | } | ||
| 1074 | |||
| 1075 | #ifdef ENABLE_PERF_REPORT | ||
| 1076 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 1077 | auto duration = | ||
| 1078 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 1079 | end_time - start_time) | ||
| 1080 | .count(); | ||
| 1081 | auto serialize_duration = | ||
| 1082 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 1083 | serialize_done_time - start_time) | ||
| 1084 | .count(); | ||
| 1085 | auto rpc_duration = | ||
| 1086 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 1087 | end_time - rpc_start_time) | ||
| 1088 | .count(); | ||
| 1089 | report("ps_client_latency", | ||
| 1090 | "UpdateParameterFlat", | ||
| 1091 | "latency_us", | ||
| 1092 | static_cast<double>(duration)); | ||
| 1093 | |||
| 1094 | const uint64_t effective_trace_id = | ||
| 1095 | trace_id == 0 | ||
| 1096 | ? static_cast<uint64_t>( | ||
| 1097 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 1098 | start_time.time_since_epoch()) | ||
| 1099 | .count()) | ||
| 1100 | : trace_id; | ||
| 1101 | std::string stage_id = | ||
| 1102 | "brpc_client::EmbUpdate|" + std::to_string(effective_trace_id); | ||
| 1103 | report("embupdate_stages", | ||
| 1104 | stage_id.c_str(), | ||
| 1105 | "client_serialize_us", | ||
| 1106 | static_cast<double>(serialize_duration)); | ||
| 1107 | report("embupdate_stages", | ||
| 1108 | stage_id.c_str(), | ||
| 1109 | "client_rpc_us", | ||
| 1110 | static_cast<double>(rpc_duration)); | ||
| 1111 | report("embupdate_stages", | ||
| 1112 | stage_id.c_str(), | ||
| 1113 | "client_total_us", | ||
| 1114 | static_cast<double>(duration)); | ||
| 1115 | report("embupdate_stages", | ||
| 1116 | stage_id.c_str(), | ||
| 1117 | "client_request_size", | ||
| 1118 | static_cast<double>(num_rows)); | ||
| 1119 | report("embupdate_stages", | ||
| 1120 | stage_id.c_str(), | ||
| 1121 | "client_embedding_dim", | ||
| 1122 | static_cast<double>(embedding_dim)); | ||
| 1123 | #endif | ||
| 1124 | |||
| 1125 | ✗ | return response.success() ? 0 : -1; | |
| 1126 | ✗ | } | |
| 1127 | |||
| 1128 | ✗ | int BRPCParameterClient::InitEmbeddingTable( | |
| 1129 | const std::string& table_name, | ||
| 1130 | const recstore::EmbeddingTableConfig& config) { | ||
| 1131 | #ifdef ENABLE_PERF_REPORT | ||
| 1132 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 1133 | #endif | ||
| 1134 | |||
| 1135 | ✗ | InitEmbeddingTableRequest request; | |
| 1136 | ✗ | InitEmbeddingTableResponse response; | |
| 1137 | request.set_table_name(table_name); | ||
| 1138 | ✗ | request.set_config_payload(config.Serialize()); | |
| 1139 | |||
| 1140 | ✗ | brpc::Controller cntl; | |
| 1141 | ✗ | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); | |
| 1142 | ✗ | stub.InitEmbeddingTable(&cntl, &request, &response, nullptr); | |
| 1143 | ✗ | if (cntl.Failed()) { | |
| 1144 | ✗ | LOG(ERROR) << "InitEmbeddingTable RPC failed: " << cntl.ErrorText(); | |
| 1145 | ✗ | return -1; | |
| 1146 | } | ||
| 1147 | |||
| 1148 | #ifdef ENABLE_PERF_REPORT | ||
| 1149 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 1150 | auto duration = | ||
| 1151 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 1152 | end_time - start_time) | ||
| 1153 | .count(); | ||
| 1154 | report("ps_client_latency", | ||
| 1155 | "InitEmbeddingTable", | ||
| 1156 | "latency_us", | ||
| 1157 | static_cast<double>(duration)); | ||
| 1158 | #endif | ||
| 1159 | |||
| 1160 | ✗ | return response.success() ? 0 : -1; | |
| 1161 | ✗ | } | |
| 1162 | |||
| 1163 | 8 | uint64_t BRPCParameterClient::EmbWriteAsync(const base::RecTensor& keys, | |
| 1164 | const base::RecTensor& values) { | ||
| 1165 |
3/6✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
|
8 | if (keys.dtype() != base::DataType::UINT64 || keys.dim() != 1) { |
| 1166 | ✗ | LOG(ERROR) << "EmbWriteAsync expects keys as 1D UINT64 tensor, got dtype=" | |
| 1167 | ✗ | << base::DataTypeToString(keys.dtype()) | |
| 1168 | ✗ | << ", dim=" << keys.dim(); | |
| 1169 | ✗ | return 0; | |
| 1170 | } | ||
| 1171 |
3/6✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
|
8 | if (values.dtype() != base::DataType::FLOAT32 || values.dim() != 2) { |
| 1172 | ✗ | LOG(ERROR) | |
| 1173 | << "EmbWriteAsync expects values as 2D FLOAT32 tensor, got dtype=" | ||
| 1174 | ✗ | << base::DataTypeToString(values.dtype()) << ", dim=" << values.dim(); | |
| 1175 | ✗ | return 0; | |
| 1176 | } | ||
| 1177 |
3/6✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
|
8 | if (values.shape(0) != keys.shape(0)) { |
| 1178 | ✗ | LOG(ERROR) << "EmbWriteAsync row mismatch: keys=" << keys.shape(0) | |
| 1179 | ✗ | << ", values=" << values.shape(0); | |
| 1180 | ✗ | return 0; | |
| 1181 | } | ||
| 1182 |
2/4✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
|
8 | if (values.shape(1) <= 0) { |
| 1183 | ✗ | LOG(ERROR) << "EmbWriteAsync invalid embedding dim: " << values.shape(1); | |
| 1184 | ✗ | return 0; | |
| 1185 | } | ||
| 1186 | |||
| 1187 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const uint64_t* key_data = keys.data_as<uint64_t>(); |
| 1188 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const float* value_data = values.data_as<float>(); |
| 1189 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | int64_t key_count = keys.shape(0); |
| 1190 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | int64_t emb_dim = values.shape(1); |
| 1191 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if (key_count == 0) { |
| 1192 | ✗ | return 0; | |
| 1193 | } | ||
| 1194 | |||
| 1195 | 8 | uint64_t prewrite_id = next_prewrite_id_++; | |
| 1196 | 8 | int request_num = | |
| 1197 | 8 | (static_cast<int>(key_count) + MAX_PARAMETER_BATCH_BRPC - 1) / | |
| 1198 | MAX_PARAMETER_BATCH_BRPC; | ||
| 1199 | |||
| 1200 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | auto it = prewrite_batches_.emplace(prewrite_id, request_num).first; |
| 1201 | 8 | struct BrpcPrewriteBatch* pb = &it->second; | |
| 1202 | |||
| 1203 |
1/2✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | recstoreps_brpc::ParameterService_Stub stub(channel_.get()); |
| 1204 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
|
16 | for (int start = 0, index = 0; start < key_count; |
| 1205 | 8 | start += MAX_PARAMETER_BATCH_BRPC, ++index) { | |
| 1206 | int key_size = | ||
| 1207 | 8 | std::min(static_cast<int>(key_count - start), MAX_PARAMETER_BATCH_BRPC); | |
| 1208 | 8 | pb->key_sizes_[index] = key_size; | |
| 1209 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | pb->controllers_[index] = std::make_unique<brpc::Controller>(); |
| 1210 | |||
| 1211 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | ParameterCompressor compressor; |
| 1212 |
2/2✓ Branch 0 taken 96 times.
✓ Branch 1 taken 8 times.
|
104 | for (int i = 0; i < key_size; ++i) { |
| 1213 | 96 | int64_t row = start + i; | |
| 1214 | 96 | ParameterPack parameter_pack; | |
| 1215 | 96 | parameter_pack.key = key_data[row]; | |
| 1216 | 96 | parameter_pack.dim = emb_dim; | |
| 1217 | 96 | parameter_pack.emb_data = value_data + row * emb_dim; | |
| 1218 |
1/2✓ Branch 1 taken 96 times.
✗ Branch 2 not taken.
|
96 | compressor.AddItem(parameter_pack, nullptr); |
| 1219 | } | ||
| 1220 | |||
| 1221 |
1/2✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
|
8 | compressor.AppendToIOBuf(&pb->controllers_[index]->request_attachment()); |
| 1222 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | google::protobuf::Closure* done = brpc::NewCallback(OnPrewriteDone, pb); |
| 1223 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | stub.PutParameter( |
| 1224 | 8 | pb->controllers_[index].get(), | |
| 1225 | 8 | &pb->requests_[index], | |
| 1226 | 8 | &pb->responses_[index], | |
| 1227 | done); | ||
| 1228 | 8 | } | |
| 1229 | |||
| 1230 | 8 | return prewrite_id; | |
| 1231 | 8 | } | |
| 1232 | |||
| 1233 | ✗ | bool BRPCParameterClient::IsWriteDone(uint64_t write_id) { | |
| 1234 | ✗ | auto it = prewrite_batches_.find(write_id); | |
| 1235 | ✗ | if (it == prewrite_batches_.end()) { | |
| 1236 | ✗ | LOG(ERROR) << "Invalid prewrite_id: " << write_id; | |
| 1237 | ✗ | return false; | |
| 1238 | } | ||
| 1239 | ✗ | auto& pb = it->second; | |
| 1240 | ✗ | return pb.completed_count_ == pb.batch_size_; | |
| 1241 | } | ||
| 1242 | |||
| 1243 | 8 | void BRPCParameterClient::WaitForWrite(uint64_t write_id) { | |
| 1244 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | auto it = prewrite_batches_.find(write_id); |
| 1245 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
|
8 | if (it == prewrite_batches_.end()) { |
| 1246 | ✗ | LOG(ERROR) << "Invalid prewrite_id: " << write_id; | |
| 1247 | ✗ | return; | |
| 1248 | } | ||
| 1249 | 8 | auto& pb = it->second; | |
| 1250 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
|
16 | for (int i = 0; i < pb.batch_size_; ++i) { |
| 1251 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
|
8 | if (!pb.controllers_[i]) { |
| 1252 | ✗ | continue; | |
| 1253 | } | ||
| 1254 |
2/4✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
|
8 | brpc::Join(pb.controllers_[i]->call_id()); |
| 1255 |
2/4✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
|
8 | if (pb.controllers_[i]->Failed()) { |
| 1256 | ✗ | LOG(ERROR) << "Async PutParameter failed: " | |
| 1257 | ✗ | << pb.controllers_[i]->ErrorText(); | |
| 1258 | } | ||
| 1259 | } | ||
| 1260 | 8 | pb.completed_count_ = pb.batch_size_; | |
| 1261 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | prewrite_batches_.erase(it); |
| 1262 | } | ||
| 1263 | |||
// Register BRPCParameterClient with the factory.
// The alias gives the macro an unqualified base-class name to work with.
// NOTE(review): presumably `brpc` is the registration key and `json` the
// constructor argument type - confirm against FACTORY_REGISTER in
// base/factory.h.
using BasePSClient = recstore::BasePSClient;
FACTORY_REGISTER(BasePSClient, brpc, BRPCParameterClient, json);
| 1267 |