GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 1 / 0 / 1
Functions: 100.0% 1 / 0 / 1
Branches: -% 0 / 0 / 0

optimizer/sparse_tensor.h
Line Branch Exec Source
1 #include <string>
2 #include <vector>
3 #include "../storage/kv_engine/base_kv.h"
4
5 #define TAG_TYPE uint8_t
6
7 enum TensorType { PARAMETER = 0, MOMENT_1 = 1, MOMENT_2 = 2 };
8
9 class SparseTensor {
10 private:
11 std::string name;
12 TensorType type;
13 TAG_TYPE tag;
14 std::vector<uint64_t> shape;
15 BaseKV* kv;
16 uint64_t concatKeyAndTag(uint64_t key, TAG_TYPE tag);
17
18 public:
19 24 SparseTensor() = default;
20 void init(std::string& name,
21 TensorType type,
22 TAG_TYPE tag,
23 std::vector<uint64_t>& shape,
24 BaseKV* kv);
25 void Get(const uint64_t key, std::string& value, unsigned tid);
26 void Put(const uint64_t key, const std::string_view& value, unsigned tid);
27 void BatchGet(const std::vector<uint64_t>& keys,
28 std::vector<base::ConstArray<float>>* values,
29 unsigned tid);
30 int64_t EmbeddingDim() const;
31 bool ApplySgdUpdateFlat(
32 const base::ConstArray<uint64_t>& keys,
33 const float* grads,
34 int64_t num_rows,
35 int64_t embedding_dim,
36 float learning_rate,
37 unsigned tid);
38 };
39