optimizer/optimizer.h
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include <string> | ||
| 4 | #include <vector> | ||
| 5 | #include <unordered_map> | ||
| 6 | #include <cmath> | ||
| 7 | #include <stdexcept> | ||
| 8 | #include "sparse_tensor.h" | ||
| 9 | #include "ps/base/base_client.h" | ||
| 10 | #include "ps/base/parameters.h" | ||
| 11 | |||
| 12 | using ::ParameterCompressReader; | ||
| 13 | using recstore::EmbeddingTableConfig; | ||
| 14 | |||
| 15 | class Optimizer { | ||
| 16 | protected: | ||
| 17 | std::unordered_map<std::string, SparseTensor*> tensor_map_; | ||
| 18 | |||
| 19 | public: | ||
| 20 | 24 | virtual ~Optimizer() { | |
| 21 |
2/2✓ Branch 5 taken 24 times.
✓ Branch 6 taken 24 times.
|
48 | for (auto& pair : tensor_map_) { |
| 22 |
1/2✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
|
24 | delete pair.second; |
| 23 | } | ||
| 24 | 24 | } | |
| 25 | |||
| 26 | virtual void Init(const std::vector<std::string> table_name, | ||
| 27 | const EmbeddingTableConfig& config, | ||
| 28 | BaseKV* base_kv) = 0; | ||
| 29 | |||
| 30 | virtual void Update(std::string table, | ||
| 31 | const ParameterCompressReader* reader, | ||
| 32 | unsigned tid) = 0; | ||
| 33 | virtual void UpdateFlat( | ||
| 34 | std::string table, | ||
| 35 | const base::ConstArray<uint64_t>& keys, | ||
| 36 | const float* grads, | ||
| 37 | int64_t num_rows, | ||
| 38 | int64_t embedding_dim, | ||
| 39 | unsigned tid) = 0; | ||
| 40 | }; | ||
| 41 | |||
| 42 | class SGD : public Optimizer { | ||
| 43 | private: | ||
| 44 | float learning_rate_; | ||
| 45 | |||
| 46 | public: | ||
| 47 | 24 | explicit SGD(float lr = 0.01) : learning_rate_(lr) {} | |
| 48 | |||
| 49 | void Init(const std::vector<std::string> table_name, | ||
| 50 | const EmbeddingTableConfig& config, | ||
| 51 | BaseKV* base_kv) override; | ||
| 52 | void Update(std::string table, | ||
| 53 | const ParameterCompressReader* reader, | ||
| 54 | unsigned tid) override; | ||
| 55 | void UpdateFlat(std::string table, | ||
| 56 | const base::ConstArray<uint64_t>& keys, | ||
| 57 | const float* grads, | ||
| 58 | int64_t num_rows, | ||
| 59 | int64_t embedding_dim, | ||
| 60 | unsigned tid) override; | ||
| 61 | }; | ||
| 62 | |||
| 63 | class AdaGrad : public Optimizer { | ||
| 64 | private: | ||
| 65 | float learning_rate_; | ||
| 66 | float epsilon_; | ||
| 67 | |||
| 68 | public: | ||
| 69 | explicit AdaGrad(float lr = 0.01, float epsilon = 1e-10) | ||
| 70 | : learning_rate_(lr), epsilon_(epsilon) {} | ||
| 71 | |||
| 72 | void Init(const std::vector<std::string> table_name, | ||
| 73 | const EmbeddingTableConfig& config, | ||
| 74 | BaseKV* base_kv) override; | ||
| 75 | void Update(std::string table, | ||
| 76 | const ParameterCompressReader* reader, | ||
| 77 | unsigned tid) override; | ||
| 78 | void UpdateFlat(std::string table, | ||
| 79 | const base::ConstArray<uint64_t>& keys, | ||
| 80 | const float* grads, | ||
| 81 | int64_t num_rows, | ||
| 82 | int64_t embedding_dim, | ||
| 83 | unsigned tid) override; | ||
| 84 | }; | ||
| 85 | |||
| 86 | class RowWiseAdaGrad : public Optimizer { | ||
| 87 | private: | ||
| 88 | float learning_rate_; | ||
| 89 | float epsilon_; | ||
| 90 | |||
| 91 | public: | ||
| 92 | explicit RowWiseAdaGrad(float lr = 0.01, float epsilon = 1e-10) | ||
| 93 | : learning_rate_(lr), epsilon_(epsilon) {} | ||
| 94 | |||
| 95 | void Init(const std::vector<std::string> table_name, | ||
| 96 | const EmbeddingTableConfig& config, | ||
| 97 | BaseKV* base_kv) override; | ||
| 98 | void Update(std::string table, | ||
| 99 | const ParameterCompressReader* reader, | ||
| 100 | unsigned tid) override; | ||
| 101 | void UpdateFlat(std::string table, | ||
| 102 | const base::ConstArray<uint64_t>& keys, | ||
| 103 | const float* grads, | ||
| 104 | int64_t num_rows, | ||
| 105 | int64_t embedding_dim, | ||
| 106 | unsigned tid) override; | ||
| 107 | }; | ||
| 108 |