ps/grpc/grpc_ps_client.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "grpc_ps_client.h" | ||
| 2 | |||
| 3 | #include <fmt/core.h> | ||
| 4 | #include <grpcpp/grpcpp.h> | ||
| 5 | |||
| 6 | #include <cstdint> | ||
| 7 | #include <cstring> | ||
| 8 | #include <future> | ||
| 9 | #include <string> | ||
| 10 | #include <vector> | ||
| 11 | |||
| 12 | #include "base/array.h" | ||
| 13 | #include "base/factory.h" | ||
| 14 | #include "base/flatc.h" | ||
| 15 | #include "base/log.h" | ||
| 16 | #include "base/timer.h" | ||
| 17 | #include "ps/base/parameters.h" | ||
| 18 | #include "ps.grpc.pb.h" | ||
| 19 | #include "ps.pb.h" | ||
| 20 | |||
| 21 | #ifdef ENABLE_PERF_REPORT | ||
| 22 | # include "base/report/report_client.h" | ||
| 23 | # include <chrono> | ||
| 24 | #endif | ||
| 25 | |||
| 26 | using grpc::Channel; | ||
| 27 | using grpc::ClientAsyncResponseReader; | ||
| 28 | using grpc::ClientContext; | ||
| 29 | using grpc::Status; | ||
| 30 | using recstoreps::CommandRequest; | ||
| 31 | using recstoreps::CommandResponse; | ||
| 32 | using recstoreps::GetParameterRequest; | ||
| 33 | using recstoreps::GetParameterResponse; | ||
| 34 | using recstoreps::InitEmbeddingTableRequest; | ||
| 35 | using recstoreps::InitEmbeddingTableResponse; | ||
| 36 | using recstoreps::PSCommand; | ||
| 37 | using recstoreps::PutParameterRequest; | ||
| 38 | using recstoreps::PutParameterResponse; | ||
| 39 | using recstoreps::UpdateParameterRequest; | ||
| 40 | using recstoreps::UpdateParameterResponse; | ||
| 41 | |||
| 42 | namespace { | ||
| 43 | |||
| 44 | 198 | void SetRpcDeadline(grpc::ClientContext* context, int timeout_ms = 15000) { | |
| 45 |
1/2✓ Branch 1 taken 198 times.
✗ Branch 2 not taken.
|
198 | context->set_deadline( |
| 46 |
1/2✓ Branch 3 taken 198 times.
✗ Branch 4 not taken.
|
198 | std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms)); |
| 47 | 198 | } | |
| 48 | |||
| 49 | ✗ | int BuildUpdateBlocksFromFlat( | |
| 50 | const base::ConstArray<uint64_t>& keys, | ||
| 51 | const float* grads, | ||
| 52 | int64_t num_rows, | ||
| 53 | int64_t embedding_dim, | ||
| 54 | ParameterCompressor* compressor) { | ||
| 55 | ✗ | if (grads == nullptr) { | |
| 56 | ✗ | LOG(ERROR) << "UpdateParameterFlat grads pointer is null"; | |
| 57 | ✗ | return -1; | |
| 58 | } | ||
| 59 | ✗ | if (num_rows < 0 || embedding_dim <= 0) { | |
| 60 | ✗ | LOG(ERROR) << "UpdateParameterFlat invalid shape: rows=" << num_rows | |
| 61 | ✗ | << " dim=" << embedding_dim; | |
| 62 | ✗ | return -1; | |
| 63 | } | ||
| 64 | ✗ | if (keys.Size() != static_cast<size_t>(num_rows)) { | |
| 65 | ✗ | LOG(ERROR) << "UpdateParameterFlat keys/grads size mismatch: " | |
| 66 | ✗ | << keys.Size() << " vs " << num_rows; | |
| 67 | ✗ | return -1; | |
| 68 | } | ||
| 69 | |||
| 70 | ✗ | for (int64_t i = 0; i < num_rows; ++i) { | |
| 71 | ✗ | ParameterPack pack; | |
| 72 | ✗ | pack.key = keys[static_cast<size_t>(i)]; | |
| 73 | ✗ | pack.dim = embedding_dim; | |
| 74 | ✗ | pack.emb_data = grads + i * embedding_dim; | |
| 75 | ✗ | compressor->AddItem(pack, nullptr); | |
| 76 | } | ||
| 77 | ✗ | return 0; | |
| 78 | } | ||
| 79 | |||
| 80 | } // namespace | ||
| 81 | |||
| 82 | DEFINE_int32(get_parameter_threads, 4, "get clients per shard"); | ||
| 83 | DEFINE_bool(parameter_client_random_init, false, ""); | ||
| 84 | |||
| 85 | // New constructor that takes JSON config. | ||
| 86 | /* | ||
| 87 | Example: load config from file | ||
| 88 | std::ifstream config_file(FLAGS_config_path); | ||
| 89 | nlohmann::json ex; | ||
| 90 | config_file >> ex; | ||
| 91 | json client_config = ex["client"]; | ||
| 92 | |||
| 93 | */ | ||
| 94 | 38 | GRPCParameterClient::GRPCParameterClient(json config) | |
| 95 |
1/2✓ Branch 2 taken 38 times.
✗ Branch 3 not taken.
|
38 | : recstore::BasePSClient(config) { |
| 96 | // Extract fields from JSON config | ||
| 97 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | host_ = config.value("host", "localhost"); |
| 98 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | port_ = config.value("port", 15000); |
| 99 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | shard_ = config.value("shard", 0); |
| 100 | 38 | nr_clients_ = FLAGS_get_parameter_threads; | |
| 101 | 38 | Initialize(); | |
| 102 | |||
| 103 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | grpc::ChannelArguments args; |
| 104 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | args.SetMaxReceiveMessageSize(-1); |
| 105 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
38 | args.SetMaxSendMessageSize(-1); |
| 106 | |||
| 107 | 38 | channel_ = grpc::CreateCustomChannel( | |
| 108 |
2/4✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
|
114 | fmt::format("{}:{}", host_, port_), |
| 109 |
1/2✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
|
76 | grpc::InsecureChannelCredentials(), |
| 110 | 38 | args); | |
| 111 |
2/4✓ Branch 1 taken 38 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 38 times.
✗ Branch 5 not taken.
|
38 | auto* raw_cq = new grpc::CompletionQueue(); |
| 112 | 38 | cq_.reset(raw_cq); | |
| 113 | |||
| 114 |
2/2✓ Branch 0 taken 152 times.
✓ Branch 1 taken 38 times.
|
190 | for (int i = 0; i < nr_clients_; i++) { |
| 115 |
1/2✓ Branch 2 taken 152 times.
✗ Branch 3 not taken.
|
152 | stubs_.push_back(nullptr); |
| 116 |
1/2✓ Branch 3 taken 152 times.
✗ Branch 4 not taken.
|
152 | stubs_[i] = recstoreps::ParameterService::NewStub(channel_); |
| 117 |
4/8✓ Branch 1 taken 152 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 152 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 152 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 152 times.
✗ Branch 11 not taken.
|
152 | LOG(INFO) << "Init PS Client Shard " << i; |
| 118 | } | ||
| 119 | 38 | } | |
| 120 | |||
| 121 | // Legacy constructor for backward compatibility | ||
| 122 | 2 | GRPCParameterClient::GRPCParameterClient( | |
| 123 | 2 | const std::string& host, int port, int shard) | |
| 124 | : recstore::BasePSClient( | ||
| 125 |
14/28✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 22 taken 6 times.
✓ Branch 23 taken 2 times.
✓ Branch 25 taken 4 times.
✓ Branch 26 taken 2 times.
✓ Branch 28 taken 4 times.
✓ Branch 29 taken 2 times.
✓ Branch 31 taken 4 times.
✓ Branch 32 taken 2 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
22 | json{{"host", host}, {"port", port}, {"shard", shard}}), |
| 126 | 2 | host_(host), | |
| 127 | 2 | port_(port), | |
| 128 | 2 | shard_(shard), | |
| 129 |
2/4✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
|
6 | nr_clients_(FLAGS_get_parameter_threads) { |
| 130 | 2 | Initialize(); | |
| 131 | |||
| 132 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | grpc::ChannelArguments args; |
| 133 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | args.SetMaxReceiveMessageSize(-1); |
| 134 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | args.SetMaxSendMessageSize(-1); |
| 135 | |||
| 136 | 2 | channel_ = grpc::CreateCustomChannel( | |
| 137 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
4 | fmt::format("{}:{}", host, port), |
| 138 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
4 | grpc::InsecureChannelCredentials(), |
| 139 | 2 | args); | |
| 140 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
|
2 | auto* raw_cq = new grpc::CompletionQueue(); |
| 141 | 2 | cq_.reset(raw_cq); | |
| 142 | |||
| 143 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 2 times.
|
10 | for (int i = 0; i < nr_clients_; i++) { |
| 144 |
1/2✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | stubs_.push_back(nullptr); |
| 145 |
1/2✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
|
8 | stubs_[i] = recstoreps::ParameterService::NewStub(channel_); |
| 146 |
4/8✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
|
8 | LOG(INFO) << "Init PS Client Shard " << i; |
| 147 | } | ||
| 148 | 2 | } | |
| 149 | |||
| 150 | ✗ | int GRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys, | |
| 151 | float* values) { | ||
| 152 | #ifdef ENABLE_PERF_REPORT | ||
| 153 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 154 | #endif | ||
| 155 | |||
| 156 | ✗ | if (FLAGS_parameter_client_random_init) { | |
| 157 | ✗ | CHECK(0) << "todo implement"; | |
| 158 | return true; | ||
| 159 | } | ||
| 160 | |||
| 161 | ✗ | get_param_key_sizes_.clear(); | |
| 162 | ✗ | get_param_status_.clear(); | |
| 163 | ✗ | get_param_requests_.clear(); | |
| 164 | ✗ | get_param_responses_.clear(); | |
| 165 | ✗ | get_param_resonse_readers_.clear(); | |
| 166 | ✗ | get_param_contexts_.clear(); | |
| 167 | |||
| 168 | int request_num = | ||
| 169 | ✗ | (keys.Size() + MAX_PARAMETER_BATCH - 1) / MAX_PARAMETER_BATCH; | |
| 170 | ✗ | get_param_status_.resize(request_num); | |
| 171 | ✗ | get_param_requests_.resize(request_num); | |
| 172 | ✗ | get_param_responses_.resize(request_num); | |
| 173 | ✗ | get_param_contexts_.resize(request_num); | |
| 174 | |||
| 175 | ✗ | for (int start = 0, index = 0; start < keys.Size(); | |
| 176 | ✗ | start += MAX_PARAMETER_BATCH, ++index) { | |
| 177 | ✗ | int key_size = std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH); | |
| 178 | ✗ | get_param_key_sizes_.emplace_back(key_size); | |
| 179 | ✗ | auto& status = get_param_status_[index]; | |
| 180 | ✗ | auto& request = get_param_requests_[index]; | |
| 181 | ✗ | auto& response = get_param_responses_[index]; | |
| 182 | ✗ | request.set_keys(reinterpret_cast<const char*>(&keys[start]), | |
| 183 | ✗ | sizeof(uint64_t) * key_size); | |
| 184 | // rpc | ||
| 185 | // grpc::ClientContext context; | ||
| 186 | ✗ | if (!get_param_contexts_[index]) { | |
| 187 | ✗ | get_param_contexts_[index] = std::make_unique<grpc::ClientContext>(); | |
| 188 | } | ||
| 189 | std::unique_ptr<ClientAsyncResponseReader<GetParameterResponse>> rpc = | ||
| 190 | ✗ | stubs_[0]->AsyncGetParameter( | |
| 191 | ✗ | get_param_contexts_[index].get(), request, cq_.get()); | |
| 192 | ✗ | rpc->Finish(&response, &status, reinterpret_cast<void*>(index)); | |
| 193 | ✗ | } | |
| 194 | ✗ | int get = 0; | |
| 195 | ✗ | while (get != request_num) { | |
| 196 | void* got_tag; | ||
| 197 | ✗ | bool ok = false; | |
| 198 | ✗ | cq_->Next(&got_tag, &ok); | |
| 199 | ✗ | if (!ok) { | |
| 200 | ✗ | LOG(ERROR) << "error"; | |
| 201 | } | ||
| 202 | ✗ | get++; | |
| 203 | } | ||
| 204 | #ifdef ENABLE_PERF_REPORT | ||
| 205 | auto after_rpc_time = std::chrono::high_resolution_clock::now(); | ||
| 206 | auto rpc_duration = | ||
| 207 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 208 | after_rpc_time - start_time) | ||
| 209 | .count(); | ||
| 210 | double start_us_for_rpc = | ||
| 211 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 212 | start_time.time_since_epoch()) | ||
| 213 | .count(); | ||
| 214 | std::string report_id_for_rpc = | ||
| 215 | "grpc_client::GetParameter|" + | ||
| 216 | std::to_string(static_cast<uint64_t>(start_us_for_rpc)); | ||
| 217 | report("embread_stages", | ||
| 218 | report_id_for_rpc.c_str(), | ||
| 219 | "rpc_duration_us", | ||
| 220 | static_cast<double>(rpc_duration)); | ||
| 221 | #endif | ||
| 222 | ✗ | size_t get_embedding_acc = 0; | |
| 223 | ✗ | int old_dimension = -1; | |
| 224 | |||
| 225 | ✗ | for (int i = 0; i < get_param_responses_.size(); ++i) { | |
| 226 | ✗ | auto& response = get_param_responses_[i]; | |
| 227 | ✗ | int key_size = get_param_key_sizes_[i]; | |
| 228 | auto parameters = reinterpret_cast<const ParameterCompressReader*>( | ||
| 229 | ✗ | response.parameter_value().data()); | |
| 230 | |||
| 231 | ✗ | if (parameters->size != key_size) { | |
| 232 | ✗ | LOG(ERROR) << "GetParameter error: " << parameters->size << " vs " | |
| 233 | ✗ | << key_size; | |
| 234 | ✗ | return false; | |
| 235 | } | ||
| 236 | |||
| 237 | ✗ | for (int index = 0; index < parameters->item_size(); ++index) { | |
| 238 | ✗ | auto item = parameters->item(index); | |
| 239 | ✗ | if (item->dim != 0) { | |
| 240 | ✗ | if (old_dimension == -1) | |
| 241 | ✗ | old_dimension = item->dim; | |
| 242 | ✗ | CHECK_EQ(item->dim, old_dimension); | |
| 243 | ✗ | std::copy_n( | |
| 244 | ✗ | item->embedding, item->dim, values + item->dim * get_embedding_acc); | |
| 245 | } else { | ||
| 246 | ✗ | RECSTORE_LOG_EVERY_MS(ERROR, 2000) | |
| 247 | ✗ | << "error; not find key " << keys[get_embedding_acc] << " in ps"; | |
| 248 | } | ||
| 249 | ✗ | get_embedding_acc++; | |
| 250 | } | ||
| 251 | } | ||
| 252 | |||
| 253 | #ifdef ENABLE_PERF_REPORT | ||
| 254 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 255 | auto duration = | ||
| 256 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 257 | end_time - start_time) | ||
| 258 | .count(); | ||
| 259 | double start_us = | ||
| 260 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 261 | start_time.time_since_epoch()) | ||
| 262 | .count(); | ||
| 263 | |||
| 264 | auto deserialize_duration = | ||
| 265 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 266 | end_time - after_rpc_time) | ||
| 267 | .count(); | ||
| 268 | |||
| 269 | report("embread_stages", | ||
| 270 | "grpc_client::GetParameter", | ||
| 271 | "deserialize_duration_us", | ||
| 272 | static_cast<double>(deserialize_duration)); | ||
| 273 | |||
| 274 | report("embread_stages", | ||
| 275 | "grpc_client::GetParameter", | ||
| 276 | "duration_us", | ||
| 277 | static_cast<double>(duration)); | ||
| 278 | |||
| 279 | report("embread_stages", | ||
| 280 | "grpc_client::GetParameter", | ||
| 281 | "request_size", | ||
| 282 | static_cast<double>(keys.Size())); | ||
| 283 | |||
| 284 | FlameGraphData grpc_client_data = { | ||
| 285 | "grpc_ps_client::GetParameter", | ||
| 286 | start_us, | ||
| 287 | 1, // level | ||
| 288 | static_cast<double>(duration), | ||
| 289 | static_cast<double>(duration)}; | ||
| 290 | |||
| 291 | std::string unique_id = "embread_debug"; | ||
| 292 | report_flame_graph("emb_read_flame_map", unique_id.c_str(), grpc_client_data); | ||
| 293 | #endif | ||
| 294 | |||
| 295 | ✗ | return true; | |
| 296 | } | ||
| 297 | |||
| 298 | 92 | int GRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys, | |
| 299 | std::vector<std::vector<float>>* values) { | ||
| 300 | #ifdef ENABLE_PERF_REPORT | ||
| 301 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 302 | #endif | ||
| 303 | |||
| 304 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 92 times.
|
92 | if (FLAGS_parameter_client_random_init) { |
| 305 | ✗ | values->clear(); | |
| 306 | ✗ | values->reserve(keys.Size()); | |
| 307 | ✗ | for (size_t i = 0; i < keys.Size(); i++) | |
| 308 | ✗ | values->emplace_back(std::vector<float>(128, 0.1)); | |
| 309 | |||
| 310 | ✗ | return true; | |
| 311 | } | ||
| 312 | |||
| 313 | 92 | values->clear(); | |
| 314 | 92 | get_param_key_sizes_.clear(); | |
| 315 | 92 | get_param_status_.clear(); | |
| 316 | 92 | get_param_requests_.clear(); | |
| 317 | 92 | get_param_responses_.clear(); | |
| 318 | 92 | get_param_resonse_readers_.clear(); | |
| 319 | 92 | get_param_contexts_.clear(); | |
| 320 | |||
| 321 | 92 | values->reserve(keys.Size()); | |
| 322 | |||
| 323 | int request_num = | ||
| 324 | 92 | (keys.Size() + MAX_PARAMETER_BATCH - 1) / MAX_PARAMETER_BATCH; | |
| 325 | |||
| 326 | 92 | get_param_status_.resize(request_num); | |
| 327 | 92 | get_param_requests_.resize(request_num); | |
| 328 | 92 | get_param_responses_.resize(request_num); | |
| 329 | 92 | get_param_contexts_.resize(request_num); | |
| 330 | |||
| 331 |
2/2✓ Branch 1 taken 92 times.
✓ Branch 2 taken 92 times.
|
184 | for (int start = 0, index = 0; start < keys.Size(); |
| 332 | 92 | start += MAX_PARAMETER_BATCH, ++index) { | |
| 333 | 92 | int key_size = std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH); | |
| 334 |
1/2✓ Branch 1 taken 92 times.
✗ Branch 2 not taken.
|
92 | get_param_key_sizes_.emplace_back(key_size); |
| 335 | 92 | auto& status = get_param_status_[index]; | |
| 336 | 92 | auto& request = get_param_requests_[index]; | |
| 337 | 92 | auto& response = get_param_responses_[index]; | |
| 338 | 92 | request.set_keys(reinterpret_cast<const char*>(&keys[start]), | |
| 339 | 92 | sizeof(uint64_t) * key_size); | |
| 340 | // rpc | ||
| 341 | // grpc::ClientContext context; | ||
| 342 |
1/2✓ Branch 2 taken 92 times.
✗ Branch 3 not taken.
|
92 | if (!get_param_contexts_[index]) { |
| 343 |
1/2✓ Branch 1 taken 92 times.
✗ Branch 2 not taken.
|
92 | get_param_contexts_[index] = std::make_unique<grpc::ClientContext>(); |
| 344 | } | ||
| 345 |
2/4✓ Branch 5 taken 92 times.
✗ Branch 6 not taken.
✓ Branch 8 taken 92 times.
✗ Branch 9 not taken.
|
184 | get_param_resonse_readers_.emplace_back(stubs_[0]->AsyncGetParameter( |
| 346 | 92 | get_param_contexts_[index].get(), request, cq_.get())); | |
| 347 | 92 | auto& rpc = get_param_resonse_readers_.back(); | |
| 348 | // GetParameter(&context, request, &response); | ||
| 349 |
1/2✓ Branch 2 taken 92 times.
✗ Branch 3 not taken.
|
92 | rpc->Finish(&response, &status, reinterpret_cast<void*>(index)); |
| 350 | } | ||
| 351 | |||
| 352 | 92 | int get = 0; | |
| 353 |
2/2✓ Branch 0 taken 92 times.
✓ Branch 1 taken 92 times.
|
184 | while (get != request_num) { |
| 354 | void* got_tag; | ||
| 355 | 92 | bool ok = false; | |
| 356 |
1/2✓ Branch 2 taken 92 times.
✗ Branch 3 not taken.
|
92 | cq_->Next(&got_tag, &ok); |
| 357 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 92 times.
|
92 | if (unlikely(!ok)) { |
| 358 | ✗ | LOG(ERROR) << "error"; | |
| 359 | } | ||
| 360 | 92 | get++; | |
| 361 | } | ||
| 362 | |||
| 363 | #ifdef ENABLE_PERF_REPORT | ||
| 364 | auto after_rpc_time = std::chrono::high_resolution_clock::now(); | ||
| 365 | auto rpc_duration = | ||
| 366 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 367 | after_rpc_time - start_time) | ||
| 368 | .count(); | ||
| 369 | double start_us_for_rpc = | ||
| 370 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 371 | start_time.time_since_epoch()) | ||
| 372 | .count(); | ||
| 373 | std::string report_id_for_rpc = | ||
| 374 | "grpc_client::GetParameter|" + | ||
| 375 | std::to_string(static_cast<uint64_t>(start_us_for_rpc)); | ||
| 376 | report("embread_stages", | ||
| 377 | report_id_for_rpc.c_str(), | ||
| 378 | "rpc_duration_us", | ||
| 379 | static_cast<double>(rpc_duration)); | ||
| 380 | #endif | ||
| 381 | |||
| 382 |
2/2✓ Branch 1 taken 92 times.
✓ Branch 2 taken 92 times.
|
184 | for (int i = 0; i < get_param_responses_.size(); ++i) { |
| 383 | 92 | auto& response = get_param_responses_[i]; | |
| 384 | 92 | int key_size = get_param_key_sizes_[i]; | |
| 385 | auto parameters = reinterpret_cast<const ParameterCompressReader*>( | ||
| 386 | 92 | response.parameter_value().data()); | |
| 387 | |||
| 388 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 92 times.
|
92 | if (unlikely(parameters->size != key_size)) { |
| 389 | ✗ | LOG(ERROR) << "GetParameter error: " << parameters->size << " vs " | |
| 390 | ✗ | << key_size; | |
| 391 | ✗ | return false; | |
| 392 | } | ||
| 393 | |||
| 394 |
2/2✓ Branch 1 taken 380 times.
✓ Branch 2 taken 92 times.
|
472 | for (int index = 0; index < parameters->item_size(); ++index) { |
| 395 | 380 | auto item = parameters->item(index); | |
| 396 |
2/2✓ Branch 0 taken 368 times.
✓ Branch 1 taken 12 times.
|
380 | if (item->dim != 0) { |
| 397 |
1/2✓ Branch 1 taken 368 times.
✗ Branch 2 not taken.
|
368 | values->emplace_back( |
| 398 |
1/2✓ Branch 2 taken 368 times.
✗ Branch 3 not taken.
|
736 | std::vector<float>(item->embedding, item->embedding + item->dim)); |
| 399 | } else { | ||
| 400 |
2/4✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 12 times.
✗ Branch 6 not taken.
|
12 | values->emplace_back(std::vector<float>(0)); |
| 401 | } | ||
| 402 | } | ||
| 403 | } | ||
| 404 | |||
| 405 | #ifdef ENABLE_PERF_REPORT | ||
| 406 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 407 | auto duration = | ||
| 408 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 409 | end_time - start_time) | ||
| 410 | .count(); | ||
| 411 | double start_us = | ||
| 412 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 413 | start_time.time_since_epoch()) | ||
| 414 | .count(); | ||
| 415 | |||
| 416 | auto deserialize_duration = | ||
| 417 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 418 | end_time - after_rpc_time) | ||
| 419 | .count(); | ||
| 420 | |||
| 421 | report("embread_stages", | ||
| 422 | "grpc_client::GetParameter", | ||
| 423 | "deserialize_duration_us", | ||
| 424 | static_cast<double>(deserialize_duration)); | ||
| 425 | |||
| 426 | report("embread_stages", | ||
| 427 | "grpc_client::GetParameter", | ||
| 428 | "duration_us", | ||
| 429 | static_cast<double>(duration)); | ||
| 430 | |||
| 431 | report("embread_stages", | ||
| 432 | "grpc_client::GetParameter", | ||
| 433 | "request_size", | ||
| 434 | static_cast<double>(keys.Size())); | ||
| 435 | |||
| 436 | FlameGraphData grpc_client_data = { | ||
| 437 | "grpc_ps_client::GetParameter", | ||
| 438 | start_us, | ||
| 439 | 1, // level | ||
| 440 | static_cast<double>(duration), | ||
| 441 | static_cast<double>(duration)}; | ||
| 442 | |||
| 443 | std::string unique_id = "embread_debug"; | ||
| 444 | report_flame_graph("emb_read_flame_map", unique_id.c_str(), grpc_client_data); | ||
| 445 | #endif | ||
| 446 | |||
| 447 | 92 | return true; | |
| 448 | } | ||
| 449 | |||
| 450 | // return prefetch id | ||
| 451 | uint64_t | ||
| 452 | 112 | GRPCParameterClient::PrefetchParameter(const base::ConstArray<uint64_t>& keys) { | |
| 453 | int request_num = | ||
| 454 | 112 | (keys.Size() + MAX_PARAMETER_BATCH - 1) / MAX_PARAMETER_BATCH; | |
| 455 | |||
| 456 |
1/2✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
|
112 | struct PrefetchBatch pb(request_num); |
| 457 | |||
| 458 |
2/2✓ Branch 1 taken 112 times.
✓ Branch 2 taken 112 times.
|
224 | for (int start = 0, index = 0; start < keys.Size(); |
| 459 | 112 | start += MAX_PARAMETER_BATCH, ++index) { | |
| 460 | 112 | int key_size = std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH); | |
| 461 | 112 | pb.key_sizes_[index] = key_size; | |
| 462 | 112 | auto& status = pb.status_[index]; | |
| 463 |
1/2✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
|
112 | if (!pb.contexts_[index]) { |
| 464 |
1/2✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
|
112 | pb.contexts_[index] = std::make_unique<grpc::ClientContext>(); |
| 465 | } | ||
| 466 | 112 | auto& request = pb.requests_[index]; | |
| 467 | 112 | auto& response = pb.responses_[index]; | |
| 468 | 112 | request.set_keys(reinterpret_cast<const char*>(&keys[start]), | |
| 469 | 112 | sizeof(uint64_t) * key_size); | |
| 470 | // rpc | ||
| 471 | // grpc::ClientContext context; | ||
| 472 |
2/4✓ Branch 5 taken 112 times.
✗ Branch 6 not taken.
✓ Branch 8 taken 112 times.
✗ Branch 9 not taken.
|
224 | pb.response_readers_.emplace_back(stubs_[0]->AsyncGetParameter( |
| 473 | 112 | pb.contexts_[index].get(), request, pb.cqs_.get())); | |
| 474 | 112 | auto& rpc = pb.response_readers_.back(); | |
| 475 | // GetParameter(&context, request, &response); | ||
| 476 |
1/2✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
|
112 | rpc->Finish(&response, &status, reinterpret_cast<void*>(index)); |
| 477 | } | ||
| 478 | 112 | uint64_t prefetch_id = 0; | |
| 479 | { | ||
| 480 |
1/2✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
|
112 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 481 | 112 | prefetch_id = next_prefetch_id_++; | |
| 482 |
1/2✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
|
112 | prefetch_batches_.emplace(prefetch_id, std::move(pb)); |
| 483 | 112 | } | |
| 484 | |||
| 485 | 112 | return prefetch_id; | |
| 486 | 112 | } | |
| 487 | |||
| 488 | 4 | bool GRPCParameterClient::IsPrefetchDone(uint64_t prefetch_id) { | |
| 489 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 490 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | auto it = prefetch_batches_.find(prefetch_id); |
| 491 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
|
4 | if (it == prefetch_batches_.end()) { |
| 492 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 493 | ✗ | return false; | |
| 494 | } | ||
| 495 | 4 | auto& pb = it->second; | |
| 496 | 4 | int request_num = pb.batch_size_; | |
| 497 | 4 | int get = 0; | |
| 498 | |||
| 499 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | if (pb.completed_count_ == pb.batch_size_) { |
| 500 | 4 | return true; | |
| 501 | } | ||
| 502 | |||
| 503 | ✗ | void* got_tag = nullptr; | |
| 504 | ✗ | bool ok = false; | |
| 505 | auto deadline = | ||
| 506 | ✗ | std::chrono::system_clock::now() + std::chrono::milliseconds(0); | |
| 507 | for (;;) { | ||
| 508 | ✗ | auto status = pb.cqs_->AsyncNext(&got_tag, &ok, deadline); | |
| 509 | ✗ | if (status == grpc::CompletionQueue::NextStatus::GOT_EVENT) { | |
| 510 | ✗ | if (unlikely(!ok)) { | |
| 511 | ✗ | LOG(ERROR) << "CompletionQueue returned not ok for prefetch"; | |
| 512 | } | ||
| 513 | ✗ | pb.completed_count_++; | |
| 514 | ✗ | if (pb.completed_count_ == pb.batch_size_) | |
| 515 | ✗ | break; | |
| 516 | deadline = | ||
| 517 | ✗ | std::chrono::system_clock::now() + std::chrono::milliseconds(0); | |
| 518 | ✗ | continue; | |
| 519 | ✗ | } else if (status == grpc::CompletionQueue::NextStatus::TIMEOUT) { | |
| 520 | ✗ | break; | |
| 521 | } else { | ||
| 522 | ✗ | LOG(ERROR) << "CompletionQueue shutdown during prefetch"; | |
| 523 | ✗ | break; | |
| 524 | } | ||
| 525 | ✗ | } | |
| 526 | ✗ | return (pb.completed_count_ == pb.batch_size_); | |
| 527 | 4 | } | |
| 528 | |||
| 529 | 212 | void GRPCParameterClient::WaitForPrefetch(uint64_t prefetch_id) { | |
| 530 |
1/2✓ Branch 1 taken 212 times.
✗ Branch 2 not taken.
|
212 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 531 |
1/2✓ Branch 1 taken 212 times.
✗ Branch 2 not taken.
|
212 | auto it = prefetch_batches_.find(prefetch_id); |
| 532 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 212 times.
|
212 | if (it == prefetch_batches_.end()) { |
| 533 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 534 | ✗ | return; | |
| 535 | } | ||
| 536 | 212 | auto& pb = it->second; | |
| 537 | 212 | void* got_tag = nullptr; | |
| 538 | 212 | bool ok = false; | |
| 539 | 212 | int idle_rounds = 0; | |
| 540 | 212 | constexpr auto kPollInterval = std::chrono::milliseconds(200); | |
| 541 | 212 | constexpr int kMaxIdleRounds = 150; // 30s | |
| 542 |
2/2✓ Branch 0 taken 112 times.
✓ Branch 1 taken 212 times.
|
324 | while (pb.completed_count_ < pb.batch_size_) { |
| 543 |
1/2✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
|
112 | auto deadline = std::chrono::system_clock::now() + kPollInterval; |
| 544 |
1/2✓ Branch 2 taken 112 times.
✗ Branch 3 not taken.
|
112 | auto status = pb.cqs_->AsyncNext(&got_tag, &ok, deadline); |
| 545 |
1/2✓ Branch 0 taken 112 times.
✗ Branch 1 not taken.
|
112 | if (status == grpc::CompletionQueue::NextStatus::GOT_EVENT) { |
| 546 | 112 | idle_rounds = 0; | |
| 547 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 112 times.
|
112 | if (unlikely(!ok)) { |
| 548 | ✗ | LOG(ERROR) << "CompletionQueue returned not ok for prefetch"; | |
| 549 | } | ||
| 550 | 112 | pb.completed_count_++; | |
| 551 | 112 | continue; | |
| 552 | } | ||
| 553 | ✗ | if (status == grpc::CompletionQueue::NextStatus::TIMEOUT) { | |
| 554 | ✗ | idle_rounds++; | |
| 555 | ✗ | if (idle_rounds >= kMaxIdleRounds) { | |
| 556 | ✗ | LOG(ERROR) << "WaitForPrefetch timed out for prefetch_id " | |
| 557 | ✗ | << prefetch_id << ", completed " << pb.completed_count_ | |
| 558 | ✗ | << "/" << pb.batch_size_; | |
| 559 | ✗ | break; | |
| 560 | } | ||
| 561 | ✗ | continue; | |
| 562 | } | ||
| 563 | ✗ | if (status == grpc::CompletionQueue::NextStatus::SHUTDOWN) { | |
| 564 | ✗ | LOG(ERROR) << "CompletionQueue shutdown while waiting prefetch"; | |
| 565 | ✗ | break; | |
| 566 | } | ||
| 567 | } | ||
| 568 |
1/2✓ Branch 1 taken 212 times.
✗ Branch 2 not taken.
|
212 | } |
| 569 | |||
| 570 | 112 | bool GRPCParameterClient::GetPrefetchResult( | |
| 571 | uint64_t prefetch_id, std::vector<std::vector<float>>* values) { | ||
| 572 |
1/2✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
|
112 | std::lock_guard<std::mutex> lk(prefetch_mu_); |
| 573 |
1/2✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
|
112 | auto it = prefetch_batches_.find(prefetch_id); |
| 574 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 112 times.
|
112 | if (it == prefetch_batches_.end()) { |
| 575 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 576 | ✗ | return false; | |
| 577 | } | ||
| 578 | 112 | auto& pb = it->second; | |
| 579 | 112 | int request_num = pb.batch_size_; | |
| 580 | |||
| 581 | 112 | values->clear(); | |
| 582 | 112 | int keys_size = 0; | |
| 583 |
2/2✓ Branch 4 taken 112 times.
✓ Branch 5 taken 112 times.
|
224 | for (const auto& size : pb.key_sizes_) { |
| 584 | 112 | keys_size += size; | |
| 585 | } | ||
| 586 |
1/2✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
|
112 | values->reserve(keys_size); |
| 587 | |||
| 588 |
2/2✓ Branch 0 taken 112 times.
✓ Branch 1 taken 112 times.
|
224 | for (int i = 0; i < request_num; ++i) { |
| 589 | 112 | auto& response = pb.responses_[i]; | |
| 590 | 112 | int key_size = pb.key_sizes_[i]; | |
| 591 | auto parameters = reinterpret_cast<const ParameterCompressReader*>( | ||
| 592 |
1/2✓ Branch 1 taken 112 times.
✗ Branch 2 not taken.
|
112 | response.parameter_value().data()); |
| 593 | |||
| 594 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 112 times.
|
112 | if (unlikely(parameters->size != key_size)) { |
| 595 | ✗ | LOG(ERROR) << "GetParameter error: " << parameters->size << " vs " | |
| 596 | ✗ | << key_size; | |
| 597 | ✗ | return false; | |
| 598 | } | ||
| 599 | |||
| 600 |
2/2✓ Branch 1 taken 520 times.
✓ Branch 2 taken 112 times.
|
632 | for (int index = 0; index < parameters->item_size(); ++index) { |
| 601 |
1/2✓ Branch 1 taken 520 times.
✗ Branch 2 not taken.
|
520 | auto item = parameters->item(index); |
| 602 |
1/2✓ Branch 0 taken 520 times.
✗ Branch 1 not taken.
|
520 | if (item->dim != 0) { |
| 603 |
1/2✓ Branch 1 taken 520 times.
✗ Branch 2 not taken.
|
520 | values->emplace_back( |
| 604 |
1/2✓ Branch 2 taken 520 times.
✗ Branch 3 not taken.
|
1040 | std::vector<float>(item->embedding, item->embedding + item->dim)); |
| 605 | } else { | ||
| 606 | ✗ | values->emplace_back(std::vector<float>(0)); | |
| 607 | } | ||
| 608 | } | ||
| 609 | } | ||
| 610 | |||
| 611 | 112 | return true; | |
| 612 | 112 | } | |
| 613 | |||
| 614 | ✗ | bool GRPCParameterClient::GetPrefetchResultFlat( | |
| 615 | uint64_t prefetch_id, | ||
| 616 | std::vector<float>* values, | ||
| 617 | int64_t* num_rows, | ||
| 618 | int64_t embedding_dim) { | ||
| 619 | ✗ | std::lock_guard<std::mutex> lk(prefetch_mu_); | |
| 620 | ✗ | auto it = prefetch_batches_.find(prefetch_id); | |
| 621 | ✗ | if (it == prefetch_batches_.end()) { | |
| 622 | ✗ | LOG(ERROR) << "Invalid prefetch_id: " << prefetch_id; | |
| 623 | ✗ | return false; | |
| 624 | } | ||
| 625 | ✗ | if (values == nullptr || num_rows == nullptr) { | |
| 626 | ✗ | LOG(ERROR) << "GetPrefetchResultFlat output pointer is null"; | |
| 627 | ✗ | return false; | |
| 628 | } | ||
| 629 | |||
| 630 | ✗ | auto& pb = it->second; | |
| 631 | ✗ | int request_num = pb.batch_size_; | |
| 632 | ✗ | int total_keys = 0; | |
| 633 | ✗ | for (const auto& size : pb.key_sizes_) { | |
| 634 | ✗ | total_keys += size; | |
| 635 | } | ||
| 636 | |||
| 637 | ✗ | *num_rows = static_cast<int64_t>(total_keys); | |
| 638 | ✗ | values->assign( | |
| 639 | ✗ | static_cast<size_t>(*num_rows) * static_cast<size_t>(embedding_dim), | |
| 640 | ✗ | 0.0f); | |
| 641 | |||
| 642 | ✗ | size_t row_offset = 0; | |
| 643 | ✗ | for (int i = 0; i < request_num; ++i) { | |
| 644 | ✗ | auto& response = pb.responses_[i]; | |
| 645 | ✗ | int key_size = pb.key_sizes_[i]; | |
| 646 | auto parameters = reinterpret_cast<const ParameterCompressReader*>( | ||
| 647 | ✗ | response.parameter_value().data()); | |
| 648 | |||
| 649 | ✗ | if (unlikely(parameters->size != key_size)) { | |
| 650 | ✗ | LOG(ERROR) << "GetParameter error: " << parameters->size << " vs " | |
| 651 | ✗ | << key_size; | |
| 652 | ✗ | return false; | |
| 653 | } | ||
| 654 | |||
| 655 | ✗ | for (int index = 0; index < parameters->item_size(); | |
| 656 | ✗ | ++index, ++row_offset) { | |
| 657 | ✗ | auto item = parameters->item(index); | |
| 658 | ✗ | if (item->dim != 0) { | |
| 659 | const int64_t copy_d = | ||
| 660 | ✗ | std::min<int64_t>(embedding_dim, static_cast<int64_t>(item->dim)); | |
| 661 | ✗ | std::memcpy(values->data() + row_offset * embedding_dim, | |
| 662 | ✗ | item->embedding, | |
| 663 | ✗ | static_cast<size_t>(copy_d) * sizeof(float)); | |
| 664 | } | ||
| 665 | } | ||
| 666 | } | ||
| 667 | |||
| 668 | ✗ | return true; | |
| 669 | ✗ | } | |
| 670 | |||
| 671 | 34 | bool GRPCParameterClient::ClearPS() { | |
| 672 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | CommandRequest request; |
| 673 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | CommandResponse response; |
| 674 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | request.set_command(PSCommand::CLEAR_PS); |
| 675 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | grpc::ClientContext context; |
| 676 |
1/2✓ Branch 1 taken 34 times.
✗ Branch 2 not taken.
|
34 | SetRpcDeadline(&context); |
| 677 |
1/2✓ Branch 3 taken 34 times.
✗ Branch 4 not taken.
|
34 | grpc::Status status = stubs_[0]->Command(&context, request, &response); |
| 678 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 34 times.
|
34 | if (!status.ok()) { |
| 679 | ✗ | LOG(ERROR) << "gRPC ClearPS failed: " << status.error_code() << " " | |
| 680 | ✗ | << status.error_message(); | |
| 681 | } | ||
| 682 | 68 | return status.ok(); | |
| 683 | 34 | } | |
| 684 | |||
| 685 | // Read n bytes from the server. The server does not access storage; | ||
| 686 | // it generates data randomly instead. | ||
| 687 | 6 | bool GRPCParameterClient::LoadFakeData(int64_t n) { | |
| 688 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | CommandRequest request; |
| 689 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | CommandResponse response; |
| 690 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | request.set_command(PSCommand::LOAD_FAKE_DATA); |
| 691 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | request.add_arg1(&n, sizeof(int64_t)); |
| 692 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | grpc::ClientContext context; |
| 693 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | SetRpcDeadline(&context); |
| 694 |
1/2✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
|
6 | grpc::Status status = stubs_[0]->Command(&context, request, &response); |
| 695 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
|
6 | if (!status.ok()) { |
| 696 | ✗ | LOG(ERROR) << "gRPC LoadFakeData failed: " << status.error_code() << " " | |
| 697 | ✗ | << status.error_message(); | |
| 698 | ✗ | return false; | |
| 699 | } | ||
| 700 |
2/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
|
6 | if (response.reply().size() != static_cast<size_t>(n)) { |
| 701 | ✗ | LOG(ERROR) << "gRPC LoadFakeData reply size mismatch: expected " << n | |
| 702 | ✗ | << ", got " << response.reply().size(); | |
| 703 | ✗ | return false; | |
| 704 | } | ||
| 705 | 6 | return true; | |
| 706 | 6 | } | |
| 707 | |||
| 708 | // Write n bytes(random generated) into the server | ||
| 709 | 6 | bool GRPCParameterClient::DumpFakeData(int64_t n) { | |
| 710 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | CommandRequest request; |
| 711 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | CommandResponse response; |
| 712 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | request.set_command(PSCommand::DUMP_FAKE_DATA); |
| 713 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | request.add_arg1(&n, sizeof(int64_t)); |
| 714 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | grpc::ClientContext context; |
| 715 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | SetRpcDeadline(&context); |
| 716 |
1/2✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
|
6 | grpc::Status status = stubs_[0]->Command(&context, request, &response); |
| 717 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
|
6 | if (!status.ok()) { |
| 718 | ✗ | LOG(ERROR) << "gRPC DumpFakeData failed: " << status.error_code() << " " | |
| 719 | ✗ | << status.error_message(); | |
| 720 | ✗ | return false; | |
| 721 | } | ||
| 722 |
3/6✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
|
6 | if (response.reply() != "ok") { |
| 723 | ✗ | LOG(ERROR) << "gRPC DumpFakeData unexpected reply: " << response.reply(); | |
| 724 | ✗ | return false; | |
| 725 | } | ||
| 726 | 6 | return true; | |
| 727 | 6 | } | |
| 728 | |||
| 729 | ✗ | bool GRPCParameterClient::LoadCkpt( | |
| 730 | const std::vector<std::string>& model_config_path, | ||
| 731 | const std::vector<std::string>& emb_file_path) { | ||
| 732 | ✗ | CommandRequest request; | |
| 733 | ✗ | CommandResponse response; | |
| 734 | ✗ | request.set_command(PSCommand::RELOAD_PS); | |
| 735 | |||
| 736 | ✗ | for (auto& each : model_config_path) { | |
| 737 | ✗ | request.add_arg1(each); | |
| 738 | } | ||
| 739 | ✗ | for (auto& each : emb_file_path) { | |
| 740 | ✗ | request.add_arg2(each); | |
| 741 | } | ||
| 742 | ✗ | grpc::ClientContext context; | |
| 743 | ✗ | SetRpcDeadline(&context, 30000); | |
| 744 | ✗ | grpc::Status status = stubs_[0]->Command(&context, request, &response); | |
| 745 | ✗ | return status.ok(); | |
| 746 | ✗ | } | |
| 747 | |||
| 748 | 150 | bool GRPCParameterClient::PutParameter( | |
| 749 | const std::vector<uint64_t>& keys, | ||
| 750 | const std::vector<std::vector<float>>& values) { | ||
| 751 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 150 times.
|
150 | if (keys.size() != values.size()) { |
| 752 | ✗ | LOG(ERROR) << "PutParameter keys/values size mismatch: " << keys.size() | |
| 753 | ✗ | << " vs " << values.size(); | |
| 754 | ✗ | return false; | |
| 755 | } | ||
| 756 |
2/2✓ Branch 1 taken 150 times.
✓ Branch 2 taken 150 times.
|
300 | for (int start = 0, index = 0; start < keys.size(); |
| 757 | 150 | start += MAX_PARAMETER_BATCH, ++index) { | |
| 758 | 150 | int key_size = std::min((int)(keys.size() - start), MAX_PARAMETER_BATCH); | |
| 759 |
1/2✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
|
150 | PutParameterRequest request; |
| 760 |
1/2✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
|
150 | PutParameterResponse response; |
| 761 |
1/2✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
|
150 | ParameterCompressor compressor; |
| 762 | 150 | std::vector<std::string> blocks; | |
| 763 |
2/2✓ Branch 0 taken 1250 times.
✓ Branch 1 taken 150 times.
|
1400 | for (int i = start; i < start + key_size; i++) { |
| 764 | 1250 | auto each_key = keys[i]; | |
| 765 | 1250 | auto& embedding = values[i]; | |
| 766 | 1250 | ParameterPack parameter_pack; | |
| 767 | 1250 | parameter_pack.key = each_key; | |
| 768 | 1250 | parameter_pack.dim = embedding.size(); | |
| 769 | 1250 | parameter_pack.emb_data = embedding.data(); | |
| 770 |
1/2✓ Branch 1 taken 1250 times.
✗ Branch 2 not taken.
|
1250 | compressor.AddItem(parameter_pack, &blocks); |
| 771 | } | ||
| 772 |
1/2✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
|
150 | compressor.ToBlock(&blocks); |
| 773 |
2/8✓ Branch 4 taken 150 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 150 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
|
150 | CHECK_EQ(blocks.size(), 1); |
| 774 |
1/2✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
|
150 | request.mutable_parameter_value()->swap(blocks[0]); |
| 775 |
1/2✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
|
150 | grpc::ClientContext context; |
| 776 |
1/2✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
|
150 | SetRpcDeadline(&context); |
| 777 |
1/2✓ Branch 3 taken 150 times.
✗ Branch 4 not taken.
|
150 | grpc::Status status = stubs_[0]->PutParameter(&context, request, &response); |
| 778 |
1/2✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
|
150 | if (status.ok()) { |
| 779 | 150 | continue; | |
| 780 | } else { | ||
| 781 | ✗ | std::cout << status.error_code() << ": " << status.error_message() | |
| 782 | ✗ | << std::endl; | |
| 783 | ✗ | return false; | |
| 784 | } | ||
| 785 |
6/12✓ Branch 1 taken 150 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 150 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 150 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 150 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 150 times.
✗ Branch 14 not taken.
✓ Branch 16 taken 150 times.
✗ Branch 17 not taken.
|
900 | } |
| 786 | 150 | return true; | |
| 787 | } | ||
| 788 | |||
| 789 | 2 | int GRPCParameterClient::UpdateParameter( | |
| 790 | const std::string& table_name, | ||
| 791 | const base::ConstArray<uint64_t>& keys, | ||
| 792 | const std::vector<std::vector<float>>* grads) { | ||
| 793 | #ifdef ENABLE_PERF_REPORT | ||
| 794 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 795 | const uint64_t trace_id = recstore::g_trace_id; | ||
| 796 | #endif | ||
| 797 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
|
2 | if (grads == nullptr) { |
| 798 | ✗ | LOG(ERROR) << "UpdateParameter grads pointer is null"; | |
| 799 | ✗ | return -1; | |
| 800 | } | ||
| 801 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
|
2 | if (keys.Size() != grads->size()) { |
| 802 | ✗ | LOG(ERROR) << "UpdateParameter keys/grads size mismatch: " << keys.Size() | |
| 803 | ✗ | << " vs " << grads->size(); | |
| 804 | ✗ | return -1; | |
| 805 | } | ||
| 806 | |||
| 807 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | ParameterCompressor compressor; |
| 808 |
2/2✓ Branch 1 taken 4 times.
✓ Branch 2 taken 2 times.
|
6 | for (size_t i = 0; i < keys.Size(); ++i) { |
| 809 | 4 | ParameterPack pack; | |
| 810 | 4 | pack.key = keys[i]; | |
| 811 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | pack.dim = grads->at(i).size(); |
| 812 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | pack.emb_data = grads->at(i).data(); |
| 813 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | compressor.AddItem(pack, nullptr); |
| 814 | } | ||
| 815 | #ifdef ENABLE_PERF_REPORT | ||
| 816 | auto serialize_done_time = std::chrono::high_resolution_clock::now(); | ||
| 817 | #endif | ||
| 818 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
|
2 | if (keys.Size() == 0) { |
| 819 | ✗ | LOG(WARNING) << "UpdateParameter no gradients to send"; | |
| 820 | ✗ | return 0; | |
| 821 | } | ||
| 822 | |||
| 823 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | UpdateParameterRequest request; |
| 824 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | UpdateParameterResponse response; |
| 825 | request.set_table_name(table_name); | ||
| 826 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
|
2 | compressor.ToBlock(request.mutable_gradients()); |
| 827 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
|
2 | if (request.gradients().empty()) { |
| 828 | ✗ | LOG(WARNING) << "UpdateParameter no serialized gradients payload"; | |
| 829 | ✗ | return 0; | |
| 830 | } | ||
| 831 | |||
| 832 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | grpc::ClientContext context; |
| 833 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | SetRpcDeadline(&context); |
| 834 | #ifdef ENABLE_PERF_REPORT | ||
| 835 | if (trace_id != 0) { | ||
| 836 | context.AddMetadata("x-recstore-trace-id", std::to_string(trace_id)); | ||
| 837 | } | ||
| 838 | auto rpc_start_time = std::chrono::high_resolution_clock::now(); | ||
| 839 | #endif | ||
| 840 | grpc::Status status = | ||
| 841 |
1/2✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
|
2 | stubs_[0]->UpdateParameter(&context, request, &response); |
| 842 | #ifdef ENABLE_PERF_REPORT | ||
| 843 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 844 | auto serialize_duration = | ||
| 845 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 846 | serialize_done_time - start_time) | ||
| 847 | .count(); | ||
| 848 | auto rpc_duration = | ||
| 849 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 850 | end_time - rpc_start_time) | ||
| 851 | .count(); | ||
| 852 | auto total_duration = | ||
| 853 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 854 | end_time - start_time) | ||
| 855 | .count(); | ||
| 856 | std::string stage_id = | ||
| 857 | "grpc_client::EmbUpdate|" + | ||
| 858 | std::to_string( | ||
| 859 | trace_id == 0 | ||
| 860 | ? static_cast<uint64_t>( | ||
| 861 | std::chrono::duration_cast< std::chrono::microseconds>( | ||
| 862 | start_time.time_since_epoch()) | ||
| 863 | .count()) | ||
| 864 | : trace_id); | ||
| 865 | report("embupdate_stages", | ||
| 866 | stage_id.c_str(), | ||
| 867 | "client_serialize_us", | ||
| 868 | static_cast<double>(serialize_duration)); | ||
| 869 | report("embupdate_stages", | ||
| 870 | stage_id.c_str(), | ||
| 871 | "client_rpc_us", | ||
| 872 | static_cast<double>(rpc_duration)); | ||
| 873 | report("embupdate_stages", | ||
| 874 | stage_id.c_str(), | ||
| 875 | "client_total_us", | ||
| 876 | static_cast<double>(total_duration)); | ||
| 877 | report("embupdate_stages", | ||
| 878 | stage_id.c_str(), | ||
| 879 | "client_request_size", | ||
| 880 | static_cast<double>(keys.Size())); | ||
| 881 | #endif | ||
| 882 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
|
2 | if (!status.ok()) { |
| 883 | ✗ | LOG(ERROR) << "UpdateParameter RPC failed: " << status.error_message(); | |
| 884 | ✗ | return -1; | |
| 885 | } | ||
| 886 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
|
2 | return response.success() ? 0 : -1; |
| 887 | 2 | } | |
| 888 | |||
| 889 | ✗ | int GRPCParameterClient::UpdateParameterFlat( | |
| 890 | const std::string& table_name, | ||
| 891 | const base::ConstArray<uint64_t>& keys, | ||
| 892 | const float* grads, | ||
| 893 | int64_t num_rows, | ||
| 894 | int64_t embedding_dim) { | ||
| 895 | #ifdef ENABLE_PERF_REPORT | ||
| 896 | auto start_time = std::chrono::high_resolution_clock::now(); | ||
| 897 | const uint64_t trace_id = recstore::g_trace_id; | ||
| 898 | #endif | ||
| 899 | ✗ | if (keys.Size() == 0) { | |
| 900 | ✗ | return 0; | |
| 901 | } | ||
| 902 | |||
| 903 | ✗ | ParameterCompressor compressor; | |
| 904 | ✗ | if (BuildUpdateBlocksFromFlat( | |
| 905 | ✗ | keys, grads, num_rows, embedding_dim, &compressor) != 0) { | |
| 906 | ✗ | return -1; | |
| 907 | } | ||
| 908 | #ifdef ENABLE_PERF_REPORT | ||
| 909 | auto serialize_done_time = std::chrono::high_resolution_clock::now(); | ||
| 910 | #endif | ||
| 911 | ✗ | UpdateParameterRequest request; | |
| 912 | ✗ | UpdateParameterResponse response; | |
| 913 | request.set_table_name(table_name); | ||
| 914 | ✗ | compressor.ToBlock(request.mutable_gradients()); | |
| 915 | ✗ | if (request.gradients().empty()) { | |
| 916 | ✗ | return 0; | |
| 917 | } | ||
| 918 | |||
| 919 | ✗ | grpc::ClientContext context; | |
| 920 | ✗ | SetRpcDeadline(&context); | |
| 921 | #ifdef ENABLE_PERF_REPORT | ||
| 922 | if (trace_id != 0) { | ||
| 923 | context.AddMetadata("x-recstore-trace-id", std::to_string(trace_id)); | ||
| 924 | } | ||
| 925 | auto rpc_start_time = std::chrono::high_resolution_clock::now(); | ||
| 926 | #endif | ||
| 927 | grpc::Status status = | ||
| 928 | ✗ | stubs_[0]->UpdateParameter(&context, request, &response); | |
| 929 | #ifdef ENABLE_PERF_REPORT | ||
| 930 | auto end_time = std::chrono::high_resolution_clock::now(); | ||
| 931 | auto serialize_duration = | ||
| 932 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 933 | serialize_done_time - start_time) | ||
| 934 | .count(); | ||
| 935 | auto rpc_duration = | ||
| 936 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 937 | end_time - rpc_start_time) | ||
| 938 | .count(); | ||
| 939 | auto total_duration = | ||
| 940 | std::chrono::duration_cast<std::chrono::microseconds>( | ||
| 941 | end_time - start_time) | ||
| 942 | .count(); | ||
| 943 | std::string stage_id = | ||
| 944 | "grpc_client::EmbUpdate|" + | ||
| 945 | std::to_string( | ||
| 946 | trace_id == 0 | ||
| 947 | ? static_cast<uint64_t>( | ||
| 948 | std::chrono::duration_cast< std::chrono::microseconds>( | ||
| 949 | start_time.time_since_epoch()) | ||
| 950 | .count()) | ||
| 951 | : trace_id); | ||
| 952 | report("embupdate_stages", | ||
| 953 | stage_id.c_str(), | ||
| 954 | "client_serialize_us", | ||
| 955 | static_cast<double>(serialize_duration)); | ||
| 956 | report("embupdate_stages", | ||
| 957 | stage_id.c_str(), | ||
| 958 | "client_rpc_us", | ||
| 959 | static_cast<double>(rpc_duration)); | ||
| 960 | report("embupdate_stages", | ||
| 961 | stage_id.c_str(), | ||
| 962 | "client_total_us", | ||
| 963 | static_cast<double>(total_duration)); | ||
| 964 | report("embupdate_stages", | ||
| 965 | stage_id.c_str(), | ||
| 966 | "client_request_size", | ||
| 967 | static_cast<double>(num_rows)); | ||
| 968 | report("embupdate_stages", | ||
| 969 | stage_id.c_str(), | ||
| 970 | "client_embedding_dim", | ||
| 971 | static_cast<double>(embedding_dim)); | ||
| 972 | #endif | ||
| 973 | ✗ | if (!status.ok()) { | |
| 974 | ✗ | LOG(ERROR) << "UpdateParameterFlat RPC failed: " << status.error_message(); | |
| 975 | ✗ | return -1; | |
| 976 | } | ||
| 977 | ✗ | return response.success() ? 0 : -1; | |
| 978 | ✗ | } | |
| 979 | |||
| 980 | 20 | int GRPCParameterClient::InitEmbeddingTable( | |
| 981 | const std::string& table_name, | ||
| 982 | const recstore::EmbeddingTableConfig& config) { | ||
| 983 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | InitEmbeddingTableRequest request; |
| 984 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | InitEmbeddingTableResponse response; |
| 985 | request.set_table_name(table_name); | ||
| 986 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
40 | request.set_config_payload(config.Serialize()); |
| 987 | |||
| 988 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | grpc::ClientContext context; |
| 989 | grpc::Status status = | ||
| 990 |
1/2✓ Branch 3 taken 20 times.
✗ Branch 4 not taken.
|
20 | stubs_[0]->InitEmbeddingTable(&context, request, &response); |
| 991 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
|
20 | if (!status.ok()) { |
| 992 | ✗ | LOG(ERROR) << "InitEmbeddingTable RPC failed: " << status.error_message(); | |
| 993 | ✗ | return -1; | |
| 994 | } | ||
| 995 |
2/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 20 times.
✗ Branch 4 not taken.
|
20 | return response.success() ? 0 : -1; |
| 996 | 20 | } | |
| 997 | |||
| 998 | // BasePSClient pure virtual implementations | ||
| 999 | // int GRPCParameterClient::GetParameter(const base::ConstArray<uint64_t>& keys, | ||
| 1000 | // float* values) { | ||
| 1001 | // return GetParameter(ConstArray<uint64_t>(keys.Data(), keys.Size()), values) | ||
| 1002 | // ? 0 : -1; | ||
| 1003 | // } | ||
| 1004 | |||
| 1005 | ✗ | int GRPCParameterClient::AsyncGetParameter( | |
| 1006 | const base::ConstArray<uint64_t>& keys, float* values) { | ||
| 1007 | ✗ | return GetParameter(keys, values); | |
| 1008 | } | ||
| 1009 | |||
| 1010 | 148 | int GRPCParameterClient::PutParameter( | |
| 1011 | const base::ConstArray<uint64_t>& keys, | ||
| 1012 | const std::vector<std::vector<float>>& values) { | ||
| 1013 |
1/2✓ Branch 5 taken 148 times.
✗ Branch 6 not taken.
|
148 | std::vector<uint64_t> key_vec(keys.Data(), keys.Data() + keys.Size()); |
| 1014 |
1/2✓ Branch 1 taken 148 times.
✗ Branch 2 not taken.
|
148 | bool success = PutParameter(key_vec, values); |
| 1015 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 148 times.
|
148 | if (!success) { |
| 1016 | ✗ | LOG(ERROR) << "PutParameter batch failed"; | |
| 1017 | } | ||
| 1018 |
1/2✓ Branch 0 taken 148 times.
✗ Branch 1 not taken.
|
296 | return success ? 1 : 0; |
| 1019 | 148 | } | |
| 1020 | |||
| 1021 | ✗ | void GRPCParameterClient::Command(recstore::PSCommand command) { | |
| 1022 | ✗ | switch (command) { | |
| 1023 | ✗ | case recstore::PSCommand::CLEAR_PS: | |
| 1024 | ✗ | ClearPS(); | |
| 1025 | ✗ | break; | |
| 1026 | ✗ | case recstore::PSCommand::RELOAD_PS: | |
| 1027 | |||
| 1028 | ✗ | LOG(WARNING) << "RELOAD_PS command requires additional parameters"; | |
| 1029 | ✗ | break; | |
| 1030 | ✗ | case recstore::PSCommand::LOAD_FAKE_DATA: { | |
| 1031 | ✗ | int64_t fake_data = 1000; | |
| 1032 | ✗ | LoadFakeData(fake_data); | |
| 1033 | ✗ | } break; | |
| 1034 | ✗ | case recstore::PSCommand::DUMP_FAKE_DATA: { | |
| 1035 | ✗ | DumpFakeData(4096); | |
| 1036 | ✗ | } break; | |
| 1037 | ✗ | default: | |
| 1038 | ✗ | LOG(ERROR) << "Unknown PS command: " << static_cast<int>(command); | |
| 1039 | ✗ | break; | |
| 1040 | } | ||
| 1041 | ✗ | } | |
| 1042 | |||
| 1043 | 8 | uint64_t GRPCParameterClient::EmbWriteAsync( | |
| 1044 | const base::ConstArray<uint64_t>& keys, | ||
| 1045 | const std::vector<std::vector<float>>& values) { | ||
| 1046 | 8 | uint64_t prewrite_id = next_prewrite_id_++; | |
| 1047 | int request_num = | ||
| 1048 | 8 | (keys.Size() + MAX_PARAMETER_BATCH - 1) / MAX_PARAMETER_BATCH; | |
| 1049 | |||
| 1050 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | struct PrewriteBatch pb(request_num); |
| 1051 | |||
| 1052 |
2/2✓ Branch 1 taken 8 times.
✓ Branch 2 taken 8 times.
|
16 | for (int start = 0, index = 0; start < keys.Size(); |
| 1053 | 8 | start += MAX_PARAMETER_BATCH, ++index) { | |
| 1054 | 8 | int key_size = std::min((int)(keys.Size() - start), MAX_PARAMETER_BATCH); | |
| 1055 | 8 | pb.key_sizes_[index] = key_size; | |
| 1056 | 8 | auto& status = pb.status_[index]; | |
| 1057 |
1/2✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | if (!pb.contexts_[index]) { |
| 1058 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | pb.contexts_[index] = std::make_unique<grpc::ClientContext>(); |
| 1059 | } | ||
| 1060 | 8 | auto& request = pb.requests_[index]; | |
| 1061 | 8 | auto& response = pb.responses_[index]; | |
| 1062 | |||
| 1063 | // Pack key/embedding pairs | ||
| 1064 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | ParameterCompressor compressor; |
| 1065 | 8 | std::vector<std::string> blocks; | |
| 1066 |
2/2✓ Branch 0 taken 96 times.
✓ Branch 1 taken 8 times.
|
104 | for (int i = start; i < start + key_size; i++) { |
| 1067 | 96 | auto each_key = keys[i]; | |
| 1068 | 96 | auto& embedding = values[i]; | |
| 1069 | 96 | ParameterPack parameter_pack; | |
| 1070 | 96 | parameter_pack.key = each_key; | |
| 1071 | 96 | parameter_pack.dim = embedding.size(); | |
| 1072 | 96 | parameter_pack.emb_data = embedding.data(); | |
| 1073 |
1/2✓ Branch 1 taken 96 times.
✗ Branch 2 not taken.
|
96 | compressor.AddItem(parameter_pack, &blocks); |
| 1074 | } | ||
| 1075 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | compressor.ToBlock(&blocks); |
| 1076 |
2/8✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
|
8 | CHECK_EQ(blocks.size(), 1); |
| 1077 | |||
| 1078 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | request.mutable_parameter_value()->swap(blocks[0]); |
| 1079 | |||
| 1080 | // Issue async RPC | ||
| 1081 |
2/4✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
|
16 | pb.response_readers_.emplace_back(stubs_[0]->AsyncPutParameter( |
| 1082 | 8 | pb.contexts_[index].get(), request, pb.cqs_.get())); | |
| 1083 | |||
| 1084 | // Async call; completion via CQ tag | ||
| 1085 | 8 | auto& rpc = pb.response_readers_.back(); | |
| 1086 |
1/2✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | rpc->Finish(&response, &status, reinterpret_cast<void*>(index)); |
| 1087 | 8 | } | |
| 1088 | |||
| 1089 | // Store batch state in prewrite_batches_ | ||
| 1090 |
1/2✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | prewrite_batches_.emplace(prewrite_id, std::move(pb)); |
| 1091 | 8 | return prewrite_id; | |
| 1092 | 8 | } | |
| 1093 | |||
| 1094 | ✗ | bool GRPCParameterClient::IsWriteDone(uint64_t write_id) { | |
| 1095 | ✗ | LOG(ERROR) << "IsWriteDone not implemented!"; | |
| 1096 | ✗ | auto it = prewrite_batches_.find(write_id); | |
| 1097 | ✗ | if (it == prewrite_batches_.end()) { | |
| 1098 | ✗ | LOG(ERROR) << "Invalid prewrite_id: " << write_id; | |
| 1099 | ✗ | return false; | |
| 1100 | } | ||
| 1101 | ✗ | auto& pb = it->second; | |
| 1102 | ✗ | return (pb.completed_count_ == pb.batch_size_); | |
| 1103 | } | ||
| 1104 | |||
| 1105 | 8 | void GRPCParameterClient::WaitForWrite(uint64_t write_id) { | |
| 1106 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | auto it = prewrite_batches_.find(write_id); |
| 1107 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
|
8 | if (it == prewrite_batches_.end()) { |
| 1108 | ✗ | LOG(ERROR) << "Invalid prewrite_id: " << write_id; | |
| 1109 | ✗ | return; | |
| 1110 | } | ||
| 1111 | 8 | auto& pb = it->second; | |
| 1112 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
|
16 | while (pb.completed_count_ < pb.batch_size_) { |
| 1113 | 8 | void* got_tag = nullptr; | |
| 1114 | 8 | bool ok = false; | |
| 1115 |
2/4✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
|
8 | if (!pb.cqs_->Next(&got_tag, &ok)) { |
| 1116 | ✗ | break; | |
| 1117 | } | ||
| 1118 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if (unlikely(!ok)) { |
| 1119 | ✗ | LOG(ERROR) << "Completion queue returned not ok for write"; | |
| 1120 | ✗ | continue; | |
| 1121 | } | ||
| 1122 | 8 | pb.completed_count_++; | |
| 1123 | } | ||
| 1124 | } | ||
| 1125 | |||
| 1126 | // Register GRPCParameterClient with the factory | ||
| 1127 | using BasePSClient = recstore::BasePSClient; | ||
| 1128 | FACTORY_REGISTER(BasePSClient, grpc, GRPCParameterClient, json); | ||
| 1129 |