GCC Code Coverage Report


Directory: src/
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 6 / 0 / 6
Functions: 100.0% 3 / 0 / 3
Branches: 50.0% 6 / 0 / 12

ps/base/base_client.h
Line Branch Exec Source
1 #pragma once
2 #include <vector>
3 #include <string>
4 #include <tuple>
5
6 #include "base/array.h"
7 #include "base/json.h"
8 #include "base/log.h"
9
10 namespace recstore {
11 struct EmbeddingTableConfig {
12 uint64_t num_embeddings, embedding_dim;
13
14 20 std::string Serialize() const {
15 nlohmann::json payload{
16
5/10
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 5 taken 20 times.
✗ Branch 6 not taken.
✓ Branch 8 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 12 taken 20 times.
✗ Branch 13 not taken.
✓ Branch 15 taken 20 times.
✗ Branch 16 not taken.
180 {"num_embeddings", num_embeddings}, {"embedding_dim", embedding_dim}};
17
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
40 return payload.dump();
18 20 }
19 };
20
21 enum class PSCommand {
22 CLEAR_PS,
23 RELOAD_PS,
24 LOAD_FAKE_DATA,
25 DUMP_FAKE_DATA,
26 };
27
28 class BasePSClient {
29 json json_config_;
30
31 public:
32 148 explicit BasePSClient(json config) : json_config_(config) {}
33 148 virtual ~BasePSClient() {}
34
35 virtual int
36 GetParameter(const base::ConstArray<uint64_t>& keys, float* values) = 0;
37
38 virtual int PutParameter(const base::ConstArray<uint64_t>& keys,
39 const std::vector<std::vector<float>>& values) = 0;
40 virtual int UpdateParameter(const std::string& table_name,
41 const base::ConstArray<uint64_t>& keys,
42 const std::vector<std::vector<float>>* grads) = 0;
43 virtual int UpdateParameterFlat(
44 const std::string& table_name,
45 const base::ConstArray<uint64_t>& keys,
46 const float* grads,
47 int64_t num_rows,
48 int64_t embedding_dim) = 0;
49
50 virtual int InitEmbeddingTable(const std::string& table_name,
51 const EmbeddingTableConfig& config) = 0;
52 virtual int
53 AsyncGetParameter(const base::ConstArray<uint64_t>& keys, float* values) = 0;
54
55 virtual void Command(PSCommand command) = 0;
56
57 virtual uint64_t
58 PrefetchParameter(const base::ConstArray<uint64_t>& keys) = 0;
59 virtual bool IsPrefetchDone(uint64_t prefetch_id) = 0;
60 virtual void WaitForPrefetch(uint64_t prefetch_id) = 0;
61 virtual bool GetPrefetchResult(uint64_t prefetch_id,
62 std::vector<std::vector<float>>* values) = 0;
63 virtual bool GetPrefetchResultFlat(
64 uint64_t prefetch_id,
65 std::vector<float>* values,
66 int64_t* num_rows,
67 int64_t embedding_dim) = 0;
68 };
69
70 } // namespace recstore
71