ps/rdma/base_client.h
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #pragma once | ||
| 2 | #include <memory> | ||
| 3 | #include <vector> | ||
| 4 | #include <string> | ||
| 5 | |||
| 6 | #include "base/array.h" | ||
| 7 | |||
| 8 | class BaseParameterClient { | ||
| 9 | public: | ||
| 10 | 30 | explicit BaseParameterClient(const std::string& host, int port, int shard) | |
| 11 | 30 | : host_(host), port_(port), shard_(shard) {} | |
| 12 | 30 | virtual ~BaseParameterClient() {} | |
| 13 | |||
| 14 | virtual int GetParameter(base::ConstArray<uint64_t> keys, | ||
| 15 | std::vector<std::vector<float>>* values) = 0; | ||
| 16 | |||
| 17 | virtual int GetParameter(base::ConstArray<uint64_t> keys, | ||
| 18 | float* values, | ||
| 19 | bool isAsync, | ||
| 20 | int async_req_id = 0) = 0; | ||
| 21 | |||
| 22 | virtual void InitThread() = 0; | ||
| 23 | |||
| 24 | ✗ | virtual void Barrier(const std::string& ss, int k) { | |
| 25 | ✗ | LOG(FATAL) << "not implementation"; | |
| 26 | } | ||
| 27 | |||
| 28 | // message buffer for received embeddings | ||
| 29 | virtual void* GetReceiveBuffer(size_t size) = 0; | ||
| 30 | |||
| 31 | virtual bool QueryRPCFinished(int rpc_id) = 0; | ||
| 32 | |||
| 33 | virtual void WaitRPCFinish(int rpc_id) = 0; | ||
| 34 | |||
| 35 | virtual void RevokeRPCResource(int rpc_id) = 0; | ||
| 36 | |||
| 37 | virtual int PutParameter(const std::vector<uint64_t>& keys, | ||
| 38 | const std::vector<std::vector<float>>& values) = 0; | ||
| 39 | |||
| 40 | ✗ | virtual int InitEmbeddingTable(const std::string& table_name, | |
| 41 | std::uint64_t num_embeddings, | ||
| 42 | std::uint64_t embedding_dim) { | ||
| 43 | ✗ | LOG(FATAL) << "not Implement"; | |
| 44 | return -1; | ||
| 45 | } | ||
| 46 | |||
| 47 | ✗ | virtual int UpdateParameter(const std::string& table_name, | |
| 48 | base::ConstArray<uint64_t> keys, | ||
| 49 | const std::vector<std::vector<float>>* grads) { | ||
| 50 | ✗ | LOG(FATAL) << "not Implement"; | |
| 51 | return -1; | ||
| 52 | } | ||
| 53 | |||
| 54 | ✗ | virtual int FakePutParameter(base::ConstArray<uint64_t> keys, float* values) { | |
| 55 | ✗ | LOG(FATAL) << "not Implement"; | |
| 56 | return 0; | ||
| 57 | } | ||
| 58 | |||
| 59 | protected: | ||
| 60 | std::string host_; | ||
| 61 | int port_; | ||
| 62 | int shard_; | ||
| 63 | }; | ||
| 64 |