optimizer/optimizer.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "optimizer.h" | ||
| 2 | #include "ps/local_shm/local_shm_stage_report.h" | ||
| 3 | #include <cstring> | ||
| 4 | |||
| 5 | namespace { | ||
| 6 | |||
| 7 | ✗ | std::vector<uint64_t> CollectReaderKeys(const ParameterCompressReader* reader) { | |
| 8 | ✗ | const int size = reader->item_size(); | |
| 9 | ✗ | std::vector<uint64_t> keys; | |
| 10 | ✗ | keys.reserve(size); | |
| 11 | ✗ | for (int i = 0; i < size; ++i) { | |
| 12 | ✗ | keys.push_back(reader->item(i)->key); | |
| 13 | } | ||
| 14 | ✗ | return keys; | |
| 15 | ✗ | } | |
| 16 | |||
| 17 | 6 | void ValidateFlatUpdateArgs(const base::ConstArray<uint64_t>& keys, | |
| 18 | const float* grads, | ||
| 19 | int64_t num_rows, | ||
| 20 | int64_t embedding_dim) { | ||
| 21 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
|
6 | if (grads == nullptr) { |
| 22 | ✗ | throw std::runtime_error("UpdateFlat grads pointer is null"); | |
| 23 | } | ||
| 24 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
|
6 | if (num_rows < 0 || embedding_dim <= 0) { |
| 25 | ✗ | throw std::runtime_error("UpdateFlat invalid rows/dim"); | |
| 26 | } | ||
| 27 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
|
6 | if (keys.Size() != static_cast<size_t>(num_rows)) { |
| 28 | ✗ | throw std::runtime_error("UpdateFlat keys size mismatch"); | |
| 29 | } | ||
| 30 | 6 | } | |
| 31 | |||
| 32 | } // namespace | ||
| 33 | |||
| 34 | 24 | void SGD::Init(const std::vector<std::string> table_name, | |
| 35 | const EmbeddingTableConfig& config, | ||
| 36 | BaseKV* base_kv) { | ||
| 37 |
4/8✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 9 taken 24 times.
✗ Branch 10 not taken.
✓ Branch 12 taken 24 times.
✗ Branch 13 not taken.
|
24 | LOG(INFO) << "SGD::Init called with " << table_name.size() << " table(s)"; |
| 38 |
2/2✓ Branch 5 taken 24 times.
✓ Branch 6 taken 24 times.
|
48 | for (const auto& name : table_name) { |
| 39 |
5/10✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 24 times.
✗ Branch 14 not taken.
|
48 | LOG(INFO) << " Initializing table: '" << name << "' with shape [" |
| 40 |
4/8✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
|
24 | << config.num_embeddings << ", " << config.embedding_dim << "]"; |
| 41 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | SparseTensor* param_tensor = new SparseTensor(); |
| 42 |
1/2✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
|
24 | std::vector<uint64_t> shape = {config.num_embeddings, config.embedding_dim}; |
| 43 | 24 | TAG_TYPE tag = 0; // PARAMETER tag | |
| 44 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | param_tensor->init( |
| 45 | const_cast<std::string&>(name), PARAMETER, tag, shape, base_kv); | ||
| 46 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | tensor_map_[name] = param_tensor; |
| 47 | 24 | } | |
| 48 |
3/6✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 5 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 9 taken 24 times.
✗ Branch 10 not taken.
|
48 | LOG(INFO) << "SGD::Init completed. tensor_map_ now has " << tensor_map_.size() |
| 49 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | << " entries"; |
| 50 | 24 | } | |
| 51 | |||
| 52 | ✗ | void SGD::Update( | |
| 53 | std::string table, const ParameterCompressReader* reader, unsigned tid) { | ||
| 54 | ✗ | auto it = tensor_map_.find(table); | |
| 55 | ✗ | if (it == tensor_map_.end()) { | |
| 56 | ✗ | LOG(ERROR) << "Table not found in SGD optimizer: '" << table << "'"; | |
| 57 | ✗ | throw std::runtime_error("Table not found: " + table); | |
| 58 | } | ||
| 59 | |||
| 60 | ✗ | int size = reader->item_size(); | |
| 61 | ✗ | std::vector<uint64_t> keys = CollectReaderKeys(reader); | |
| 62 | |||
| 63 | ✗ | std::vector<base::ConstArray<float>> current_values; | |
| 64 | ✗ | it->second->BatchGet(keys, ¤t_values, tid); | |
| 65 | |||
| 66 | ✗ | for (int i = 0; i < size; ++i) { | |
| 67 | ✗ | const auto* item = reader->item(i); | |
| 68 | ✗ | if (current_values[i].Size() == 0) { | |
| 69 | // If key not found, we fallback to Put to initialize it | ||
| 70 | ✗ | std::vector<float> zero_init(item->dim, 0.0f); | |
| 71 | ✗ | for (int j = 0; j < item->dim; ++j) { | |
| 72 | ✗ | zero_init[j] = -learning_rate_ * item->data()[j]; | |
| 73 | } | ||
| 74 | std::string val_str( | ||
| 75 | ✗ | (char*)zero_init.data(), zero_init.size() * sizeof(float)); | |
| 76 | ✗ | it->second->Put(item->key, val_str, tid); | |
| 77 | ✗ | continue; | |
| 78 | ✗ | } | |
| 79 | |||
| 80 | ✗ | float* data = const_cast<float*>(current_values[i].Data()); | |
| 81 | ✗ | int dim = std::min(current_values[i].Size(), item->dim); | |
| 82 | |||
| 83 | ✗ | #pragma omp simd | |
| 84 | for (int j = 0; j < dim; ++j) { | ||
| 85 | ✗ | data[j] -= learning_rate_ * item->data()[j]; | |
| 86 | } | ||
| 87 | } | ||
| 88 | ✗ | } | |
| 89 | |||
| 90 | 6 | void SGD::UpdateFlat( | |
| 91 | std::string table, | ||
| 92 | const base::ConstArray<uint64_t>& keys, | ||
| 93 | const float* grads, | ||
| 94 | int64_t num_rows, | ||
| 95 | int64_t embedding_dim, | ||
| 96 | unsigned tid) { | ||
| 97 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | ValidateFlatUpdateArgs(keys, grads, num_rows, embedding_dim); |
| 98 | |||
| 99 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | auto it = tensor_map_.find(table); |
| 100 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
|
6 | if (it == tensor_map_.end()) { |
| 101 | ✗ | LOG(ERROR) << "Table not found in SGD optimizer: '" << table << "'"; | |
| 102 | ✗ | throw std::runtime_error("Table not found: " + table); | |
| 103 | } | ||
| 104 |
2/4✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
|
6 | if (it->second->EmbeddingDim() != embedding_dim) { |
| 105 | ✗ | throw std::runtime_error( | |
| 106 | ✗ | "SGD::UpdateFlat embedding_dim mismatch for table " + table); | |
| 107 | } | ||
| 108 | |||
| 109 | 6 | const auto direct_update_start = std::chrono::steady_clock::now(); | |
| 110 |
2/4✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
|
6 | if (it->second->ApplySgdUpdateFlat( |
| 111 | keys, grads, num_rows, embedding_dim, learning_rate_, tid)) { | ||
| 112 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | recstore::ReportLocalShmStageMetric( |
| 113 | "sgd_update_direct_us", | ||
| 114 | recstore::LocalShmElapsedUs(direct_update_start)); | ||
| 115 | 6 | return; | |
| 116 | } | ||
| 117 | |||
| 118 | ✗ | std::vector<uint64_t> key_vec(keys.Data(), keys.Data() + keys.Size()); | |
| 119 | ✗ | const auto batch_get_start = std::chrono::steady_clock::now(); | |
| 120 | ✗ | std::vector<base::ConstArray<float>> current_values; | |
| 121 | ✗ | it->second->BatchGet(key_vec, ¤t_values, tid); | |
| 122 | ✗ | recstore::ReportLocalShmStageMetric( | |
| 123 | "sgd_update_batch_get_us", recstore::LocalShmElapsedUs(batch_get_start)); | ||
| 124 | |||
| 125 | ✗ | const auto apply_start = std::chrono::steady_clock::now(); | |
| 126 | ✗ | int64_t missing_rows = 0; | |
| 127 | ✗ | for (int64_t row = 0; row < num_rows; ++row) { | |
| 128 | ✗ | const float* row_grad = grads + row * embedding_dim; | |
| 129 | ✗ | const auto& current = current_values[static_cast<size_t>(row)]; | |
| 130 | ✗ | if (current.Size() == 0) { | |
| 131 | ✗ | ++missing_rows; | |
| 132 | ✗ | std::vector<float> zero_init(static_cast<size_t>(embedding_dim), 0.0f); | |
| 133 | ✗ | for (int64_t col = 0; col < embedding_dim; ++col) { | |
| 134 | ✗ | zero_init[static_cast<size_t>(col)] = -learning_rate_ * row_grad[col]; | |
| 135 | } | ||
| 136 | ✗ | std::string val_str(reinterpret_cast<char*>(zero_init.data()), | |
| 137 | ✗ | zero_init.size() * sizeof(float)); | |
| 138 | ✗ | it->second->Put(keys[static_cast<size_t>(row)], val_str, tid); | |
| 139 | ✗ | continue; | |
| 140 | ✗ | } | |
| 141 | ✗ | if (static_cast<int64_t>(current.Size()) != embedding_dim) { | |
| 142 | ✗ | throw std::runtime_error( | |
| 143 | ✗ | "SGD::UpdateFlat embedding_dim mismatch for table " + table); | |
| 144 | } | ||
| 145 | |||
| 146 | ✗ | float* data = const_cast<float*>(current.Data()); | |
| 147 | ✗ | #pragma omp simd | |
| 148 | for (int64_t col = 0; col < embedding_dim; ++col) { | ||
| 149 | ✗ | data[col] -= learning_rate_ * row_grad[col]; | |
| 150 | } | ||
| 151 | } | ||
| 152 | ✗ | recstore::ReportLocalShmStageMetric( | |
| 153 | "sgd_update_apply_us", recstore::LocalShmElapsedUs(apply_start)); | ||
| 154 | ✗ | recstore::ReportLocalShmStageMetric( | |
| 155 | "sgd_update_missing_rows", static_cast<double>(missing_rows)); | ||
| 156 | ✗ | } | |
| 157 | |||
| 158 | ✗ | void AdaGrad::Init(const std::vector<std::string> table_name, | |
| 159 | const EmbeddingTableConfig& config, | ||
| 160 | BaseKV* base_kv) { | ||
| 161 | ✗ | for (const auto& name : table_name) { | |
| 162 | ✗ | SparseTensor* param_tensor = new SparseTensor(); | |
| 163 | ✗ | std::vector<uint64_t> shape = {config.num_embeddings, config.embedding_dim}; | |
| 164 | ✗ | TAG_TYPE tag = 0; | |
| 165 | ✗ | param_tensor->init( | |
| 166 | const_cast<std::string&>(name), PARAMETER, tag, shape, base_kv); | ||
| 167 | ✗ | tensor_map_[name] = param_tensor; | |
| 168 | |||
| 169 | ✗ | std::string acc_table_name = name + "_accumulated_grad"; | |
| 170 | ✗ | SparseTensor* acc_tensor = new SparseTensor(); | |
| 171 | ✗ | acc_tensor->init( | |
| 172 | const_cast<std::string&>(acc_table_name), | ||
| 173 | MOMENT_1, | ||
| 174 | tag, | ||
| 175 | shape, | ||
| 176 | base_kv); | ||
| 177 | ✗ | tensor_map_[acc_table_name] = acc_tensor; | |
| 178 | ✗ | } | |
| 179 | ✗ | } | |
| 180 | |||
| 181 | ✗ | void AdaGrad::Update( | |
| 182 | std::string table, const ParameterCompressReader* reader, unsigned tid) { | ||
| 183 | ✗ | auto param_it = tensor_map_.find(table); | |
| 184 | ✗ | if (param_it == tensor_map_.end()) { | |
| 185 | ✗ | throw std::runtime_error("Table not found: " + table); | |
| 186 | } | ||
| 187 | |||
| 188 | ✗ | std::string acc_table = table + "_accumulated_grad"; | |
| 189 | ✗ | auto acc_it = tensor_map_.find(acc_table); | |
| 190 | ✗ | if (acc_it == tensor_map_.end()) { | |
| 191 | ✗ | throw std::runtime_error( | |
| 192 | ✗ | "Accumulated gradient table not found: " + acc_table); | |
| 193 | } | ||
| 194 | |||
| 195 | ✗ | int size = reader->item_size(); | |
| 196 | ✗ | std::vector<uint64_t> keys = CollectReaderKeys(reader); | |
| 197 | |||
| 198 | ✗ | std::vector<base::ConstArray<float>> current_values; | |
| 199 | ✗ | std::vector<base::ConstArray<float>> acc_values; | |
| 200 | ✗ | param_it->second->BatchGet(keys, ¤t_values, tid); | |
| 201 | ✗ | acc_it->second->BatchGet(keys, &acc_values, tid); | |
| 202 | |||
| 203 | ✗ | for (int i = 0; i < size; ++i) { | |
| 204 | ✗ | const auto* item = reader->item(i); | |
| 205 | ✗ | if (current_values[i].Size() == 0 || acc_values[i].Size() == 0) { | |
| 206 | // Fallback to sequential initialization if not found | ||
| 207 | // (This is rare in training but kept for robustness) | ||
| 208 | ✗ | continue; | |
| 209 | } | ||
| 210 | |||
| 211 | ✗ | float* param_data = const_cast<float*>(current_values[i].Data()); | |
| 212 | ✗ | float* acc_data = const_cast<float*>(acc_values[i].Data()); | |
| 213 | ✗ | int dim = std::min(current_values[i].Size(), item->dim); | |
| 214 | |||
| 215 | ✗ | #pragma omp simd | |
| 216 | for (int j = 0; j < dim; ++j) { | ||
| 217 | ✗ | acc_data[j] += item->data()[j] * item->data()[j]; | |
| 218 | ✗ | float adaptive_lr = learning_rate_ / (std::sqrt(acc_data[j]) + epsilon_); | |
| 219 | ✗ | param_data[j] -= adaptive_lr * item->data()[j]; | |
| 220 | } | ||
| 221 | } | ||
| 222 | ✗ | } | |
| 223 | |||
| 224 | ✗ | void AdaGrad::UpdateFlat( | |
| 225 | std::string table, | ||
| 226 | const base::ConstArray<uint64_t>& keys, | ||
| 227 | const float* grads, | ||
| 228 | int64_t num_rows, | ||
| 229 | int64_t embedding_dim, | ||
| 230 | unsigned tid) { | ||
| 231 | ✗ | ValidateFlatUpdateArgs(keys, grads, num_rows, embedding_dim); | |
| 232 | |||
| 233 | ✗ | auto param_it = tensor_map_.find(table); | |
| 234 | ✗ | if (param_it == tensor_map_.end()) { | |
| 235 | ✗ | throw std::runtime_error("Table not found: " + table); | |
| 236 | } | ||
| 237 | |||
| 238 | ✗ | std::string acc_table = table + "_accumulated_grad"; | |
| 239 | ✗ | auto acc_it = tensor_map_.find(acc_table); | |
| 240 | ✗ | if (acc_it == tensor_map_.end()) { | |
| 241 | ✗ | throw std::runtime_error( | |
| 242 | ✗ | "Accumulated gradient table not found: " + acc_table); | |
| 243 | } | ||
| 244 | |||
| 245 | ✗ | std::vector<uint64_t> key_vec(keys.Data(), keys.Data() + keys.Size()); | |
| 246 | ✗ | std::vector<base::ConstArray<float>> current_values; | |
| 247 | ✗ | std::vector<base::ConstArray<float>> acc_values; | |
| 248 | ✗ | param_it->second->BatchGet(key_vec, ¤t_values, tid); | |
| 249 | ✗ | acc_it->second->BatchGet(key_vec, &acc_values, tid); | |
| 250 | |||
| 251 | ✗ | for (int64_t row = 0; row < num_rows; ++row) { | |
| 252 | ✗ | const auto& current = current_values[static_cast<size_t>(row)]; | |
| 253 | ✗ | const auto& acc = acc_values[static_cast<size_t>(row)]; | |
| 254 | ✗ | if (current.Size() == 0 || acc.Size() == 0) { | |
| 255 | ✗ | continue; | |
| 256 | } | ||
| 257 | ✗ | if (static_cast<int64_t>(current.Size()) != embedding_dim || | |
| 258 | ✗ | static_cast<int64_t>(acc.Size()) != embedding_dim) { | |
| 259 | ✗ | throw std::runtime_error( | |
| 260 | ✗ | "AdaGrad::UpdateFlat embedding_dim mismatch for table " + table); | |
| 261 | } | ||
| 262 | |||
| 263 | ✗ | const float* row_grad = grads + row * embedding_dim; | |
| 264 | ✗ | float* param_data = const_cast<float*>(current.Data()); | |
| 265 | ✗ | float* acc_data = const_cast<float*>(acc.Data()); | |
| 266 | ✗ | #pragma omp simd | |
| 267 | for (int64_t col = 0; col < embedding_dim; ++col) { | ||
| 268 | ✗ | acc_data[col] += row_grad[col] * row_grad[col]; | |
| 269 | float adaptive_lr = | ||
| 270 | ✗ | learning_rate_ / (std::sqrt(acc_data[col]) + epsilon_); | |
| 271 | ✗ | param_data[col] -= adaptive_lr * row_grad[col]; | |
| 272 | } | ||
| 273 | } | ||
| 274 | ✗ | } | |
| 275 | |||
| 276 | ✗ | void RowWiseAdaGrad::Init(const std::vector<std::string> table_name, | |
| 277 | const EmbeddingTableConfig& config, | ||
| 278 | BaseKV* base_kv) { | ||
| 279 | ✗ | for (const auto& name : table_name) { | |
| 280 | ✗ | SparseTensor* param_tensor = new SparseTensor(); | |
| 281 | ✗ | std::vector<uint64_t> shape = {config.num_embeddings, config.embedding_dim}; | |
| 282 | ✗ | TAG_TYPE tag = 0; | |
| 283 | ✗ | param_tensor->init( | |
| 284 | const_cast<std::string&>(name), PARAMETER, tag, shape, base_kv); | ||
| 285 | ✗ | tensor_map_[name] = param_tensor; | |
| 286 | |||
| 287 | ✗ | std::string acc_table_name = name + "_rowwise_accumulated_grad"; | |
| 288 | ✗ | SparseTensor* acc_tensor = new SparseTensor(); | |
| 289 | std::vector<uint64_t> acc_shape = { | ||
| 290 | ✗ | config.num_embeddings, 1}; // One value per row | |
| 291 | ✗ | acc_tensor->init( | |
| 292 | const_cast<std::string&>(acc_table_name), | ||
| 293 | MOMENT_1, | ||
| 294 | tag, | ||
| 295 | acc_shape, | ||
| 296 | base_kv); | ||
| 297 | ✗ | tensor_map_[acc_table_name] = acc_tensor; | |
| 298 | ✗ | } | |
| 299 | ✗ | } | |
| 300 | |||
| 301 | ✗ | void RowWiseAdaGrad::Update( | |
| 302 | std::string table, const ParameterCompressReader* reader, unsigned tid) { | ||
| 303 | ✗ | auto param_it = tensor_map_.find(table); | |
| 304 | ✗ | if (param_it == tensor_map_.end()) { | |
| 305 | ✗ | throw std::runtime_error("Table not found: " + table); | |
| 306 | } | ||
| 307 | |||
| 308 | ✗ | std::string acc_table = table + "_rowwise_accumulated_grad"; | |
| 309 | ✗ | auto acc_it = tensor_map_.find(acc_table); | |
| 310 | ✗ | if (acc_it == tensor_map_.end()) { | |
| 311 | ✗ | throw std::runtime_error( | |
| 312 | ✗ | "Row-wise accumulated gradient table not found: " + acc_table); | |
| 313 | } | ||
| 314 | |||
| 315 | ✗ | int size = reader->item_size(); | |
| 316 | ✗ | std::vector<uint64_t> keys = CollectReaderKeys(reader); | |
| 317 | |||
| 318 | ✗ | std::vector<base::ConstArray<float>> current_values; | |
| 319 | ✗ | std::vector<base::ConstArray<float>> acc_values; | |
| 320 | ✗ | param_it->second->BatchGet(keys, ¤t_values, tid); | |
| 321 | ✗ | acc_it->second->BatchGet(keys, &acc_values, tid); | |
| 322 | |||
| 323 | ✗ | for (int i = 0; i < size; ++i) { | |
| 324 | ✗ | const auto* item = reader->item(i); | |
| 325 | ✗ | if (current_values[i].Size() == 0 || acc_values[i].Size() == 0) { | |
| 326 | ✗ | continue; | |
| 327 | } | ||
| 328 | |||
| 329 | ✗ | float* param_data = const_cast<float*>(current_values[i].Data()); | |
| 330 | ✗ | float* acc_data = const_cast<float*>(acc_values[i].Data()); | |
| 331 | ✗ | int dim = std::min(current_values[i].Size(), item->dim); | |
| 332 | |||
| 333 | ✗ | float grad_square_mean = 0.0; | |
| 334 | ✗ | #pragma omp simd reduction(+ : grad_square_mean) | |
| 335 | for (int j = 0; j < dim; ++j) { | ||
| 336 | ✗ | grad_square_mean += item->data()[j] * item->data()[j]; | |
| 337 | } | ||
| 338 | ✗ | grad_square_mean /= dim; | |
| 339 | |||
| 340 | ✗ | acc_data[0] += grad_square_mean; | |
| 341 | |||
| 342 | ✗ | float adaptive_lr = learning_rate_ / (std::sqrt(acc_data[0]) + epsilon_); | |
| 343 | ✗ | #pragma omp simd | |
| 344 | for (int j = 0; j < dim; ++j) { | ||
| 345 | ✗ | param_data[j] -= adaptive_lr * item->data()[j]; | |
| 346 | } | ||
| 347 | } | ||
| 348 | ✗ | } | |
| 349 | |||
| 350 | ✗ | void RowWiseAdaGrad::UpdateFlat( | |
| 351 | std::string table, | ||
| 352 | const base::ConstArray<uint64_t>& keys, | ||
| 353 | const float* grads, | ||
| 354 | int64_t num_rows, | ||
| 355 | int64_t embedding_dim, | ||
| 356 | unsigned tid) { | ||
| 357 | ✗ | ValidateFlatUpdateArgs(keys, grads, num_rows, embedding_dim); | |
| 358 | |||
| 359 | ✗ | auto param_it = tensor_map_.find(table); | |
| 360 | ✗ | if (param_it == tensor_map_.end()) { | |
| 361 | ✗ | throw std::runtime_error("Table not found: " + table); | |
| 362 | } | ||
| 363 | |||
| 364 | ✗ | std::string acc_table = table + "_rowwise_accumulated_grad"; | |
| 365 | ✗ | auto acc_it = tensor_map_.find(acc_table); | |
| 366 | ✗ | if (acc_it == tensor_map_.end()) { | |
| 367 | ✗ | throw std::runtime_error( | |
| 368 | ✗ | "Row-wise accumulated gradient table not found: " + acc_table); | |
| 369 | } | ||
| 370 | |||
| 371 | ✗ | std::vector<uint64_t> key_vec(keys.Data(), keys.Data() + keys.Size()); | |
| 372 | ✗ | std::vector<base::ConstArray<float>> current_values; | |
| 373 | ✗ | std::vector<base::ConstArray<float>> acc_values; | |
| 374 | ✗ | param_it->second->BatchGet(key_vec, ¤t_values, tid); | |
| 375 | ✗ | acc_it->second->BatchGet(key_vec, &acc_values, tid); | |
| 376 | |||
| 377 | ✗ | for (int64_t row = 0; row < num_rows; ++row) { | |
| 378 | ✗ | const auto& current = current_values[static_cast<size_t>(row)]; | |
| 379 | ✗ | const auto& acc = acc_values[static_cast<size_t>(row)]; | |
| 380 | ✗ | if (current.Size() == 0 || acc.Size() == 0) { | |
| 381 | ✗ | continue; | |
| 382 | } | ||
| 383 | ✗ | if (static_cast<int64_t>(current.Size()) != embedding_dim || | |
| 384 | ✗ | acc.Size() != 1) { | |
| 385 | ✗ | throw std::runtime_error( | |
| 386 | ✗ | "RowWiseAdaGrad::UpdateFlat embedding_dim mismatch for table " + | |
| 387 | ✗ | table); | |
| 388 | } | ||
| 389 | |||
| 390 | ✗ | const float* row_grad = grads + row * embedding_dim; | |
| 391 | ✗ | float* param_data = const_cast<float*>(current.Data()); | |
| 392 | ✗ | float* acc_data = const_cast<float*>(acc.Data()); | |
| 393 | |||
| 394 | ✗ | float grad_square_mean = 0.0f; | |
| 395 | ✗ | #pragma omp simd reduction(+ : grad_square_mean) | |
| 396 | for (int64_t col = 0; col < embedding_dim; ++col) { | ||
| 397 | ✗ | grad_square_mean += row_grad[col] * row_grad[col]; | |
| 398 | } | ||
| 399 | ✗ | grad_square_mean /= static_cast<float>(embedding_dim); | |
| 400 | |||
| 401 | ✗ | acc_data[0] += grad_square_mean; | |
| 402 | ✗ | float adaptive_lr = learning_rate_ / (std::sqrt(acc_data[0]) + epsilon_); | |
| 403 | ✗ | #pragma omp simd | |
| 404 | for (int64_t col = 0; col < embedding_dim; ++col) { | ||
| 405 | ✗ | param_data[col] -= adaptive_lr * row_grad[col]; | |
| 406 | } | ||
| 407 | } | ||
| 408 | ✗ | } | |
| 409 |