GCC Code Coverage Report


Directory: src/
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 40.0% 14 / 0 / 35
Functions: 42.9% 3 / 0 / 7
Branches: 18.8% 3 / 0 / 16

optimizer/sparse_tensor.cpp
Line Branch Exec Source
1 #include "sparse_tensor.h"
2
3 uint64_t SparseTensor::concatKeyAndTag(uint64_t key, TAG_TYPE tag) {
4 constexpr int tag_bits = sizeof(TAG_TYPE) * 8;
5 constexpr int shift = (sizeof(uint64_t) * 8) - tag_bits;
6 key &= (~0ULL >> tag_bits);
7 return (static_cast<uint64_t>(tag) << shift) | key;
8 }
9
10 24 void SparseTensor::init(std::string& name,
11 TensorType type,
12 TAG_TYPE tag,
13 std::vector<uint64_t>& shape,
14 BaseKV* kv) {
15 24 this->name = name;
16 24 this->type = type;
17 24 this->tag = tag;
18 24 this->shape = shape;
19 24 this->kv = kv;
20 24 }
21
22 void SparseTensor::Get(const uint64_t key, std::string& value, unsigned tid) {
23 auto _key = concatKeyAndTag(key, tag);
24 kv->Get(_key, value, tid);
25 }
26
27 void SparseTensor::Put(
28 const uint64_t key, const std::string_view& value, unsigned tid) {
29 auto _key = concatKeyAndTag(key, tag);
30 kv->Put(_key, value, tid);
31 }
32
33 void SparseTensor::BatchGet(const std::vector<uint64_t>& keys,
34 std::vector<base::ConstArray<float>>* values,
35 unsigned tid) {
36 std::vector<uint64_t> hashed_keys;
37 hashed_keys.reserve(keys.size());
38 for (auto k : keys) {
39 hashed_keys.push_back(concatKeyAndTag(k, tag));
40 }
41 kv->BatchGet(base::ConstArray<uint64_t>(hashed_keys), values, tid);
42 }
43
44 6 int64_t SparseTensor::EmbeddingDim() const {
45
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
6 if (shape.size() < 2) {
46 return 0;
47 }
48 6 return static_cast<int64_t>(shape[1]);
49 }
50
51 6 bool SparseTensor::ApplySgdUpdateFlat(
52 const base::ConstArray<uint64_t>& keys,
53 const float* grads,
54 int64_t num_rows,
55 int64_t embedding_dim,
56 float learning_rate,
57 unsigned tid) {
58
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
12 return kv != nullptr &&
59
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 kv->ApplySgdUpdateFlat(
60 12 keys, grads, num_rows, embedding_dim, learning_rate, tag, tid);
61 }
62