GCC Code Coverage Report


Directory: src/
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
Summary (columns: Coverage, Exec / Excl / Total):
Lines: 100.0% 1 / 0 / 1
Functions: 100.0% 1 / 0 / 1
Branches: -% 0 / 0 / 0

ps/rdma/rdma_ps_client_adapter.h
Line Branch Exec Source
1 #pragma once
2
3 #include <memory>
4 #include <mutex>
5 #include <string>
6 #include <thread>
7 #include <unordered_map>
8 #include <unordered_set>
9 #include <vector>
10
11 #include "base/json.h"
12 #include "ps/base/base_client.h"
13 #include "ps/rdma/petps_client.h"
14
15 namespace recstore {
16
17 void InitializeRdmaProcessRuntime();
18
19 class RDMAPSClientAdapter : public BasePSClient {
20 public:
21 explicit RDMAPSClientAdapter(json config);
22 8 ~RDMAPSClientAdapter() override = default;
23
24 int GetParameter(const base::ConstArray<uint64_t>& keys,
25 float* values) override;
26 int PutParameter(const base::ConstArray<uint64_t>& keys,
27 const std::vector<std::vector<float>>& values) override;
28 int UpdateParameter(const std::string& table_name,
29 const base::ConstArray<uint64_t>& keys,
30 const std::vector<std::vector<float>>* grads) override;
31 int UpdateParameterFlat(const std::string& table_name,
32 const base::ConstArray<uint64_t>& keys,
33 const float* grads,
34 int64_t num_rows,
35 int64_t embedding_dim) override;
36 int InitEmbeddingTable(const std::string& table_name,
37 const EmbeddingTableConfig& config) override;
38 int AsyncGetParameter(const base::ConstArray<uint64_t>& keys,
39 float* values) override;
40 void Command(PSCommand command) override;
41 uint64_t PrefetchParameter(const base::ConstArray<uint64_t>& keys) override;
42 bool IsPrefetchDone(uint64_t prefetch_id) override;
43 void WaitForPrefetch(uint64_t prefetch_id) override;
44 bool GetPrefetchResult(uint64_t prefetch_id,
45 std::vector<std::vector<float>>* values) override;
46 bool GetPrefetchResultFlat(uint64_t prefetch_id,
47 std::vector<float>* values,
48 int64_t* num_rows,
49 int64_t embedding_dim) override;
50
51 private:
52 struct TableState {
53 EmbeddingTableConfig config;
54 };
55
56 struct PendingShardRpc {
57 int shard_id = 0;
58 int client_index = 0;
59 int rpc_id = -1;
60 std::vector<std::size_t> original_positions;
61 void* recv_buffer = nullptr;
62 std::size_t key_count = 0;
63 };
64
65 struct BatchRequest {
66 float* user_buffer = nullptr;
67 bool assembled = false;
68 std::size_t total_key_count = 0;
69 std::int32_t status_code =
70 static_cast<std::int32_t>(petps::RpcStatus::kPending);
71 std::vector<PendingShardRpc> shard_rpcs;
72 };
73
74 struct ShardChunk {
75 int shard_id = 0;
76 int client_index = 0;
77 std::vector<uint64_t> keys;
78 std::vector<std::size_t> positions;
79 };
80
81 struct PrefetchState {
82 float* buffer = nullptr;
83 std::size_t buffer_id = 0;
84 int rpc_id = -1;
85 int64_t key_count = 0;
86 int64_t embedding_dim = 0;
87 bool borrowed_response = false;
88 bool batch_response = false;
89 };
90
91 void EnsureClientInitialized();
92 void EnsureThreadInitialized();
93 void EnsureTableReady(const std::string& table_name, int64_t embedding_dim);
94 int64_t DefaultEmbeddingDimOrThrow() const;
95 std::size_t MaxGetKeysPerRpc() const;
96 std::size_t MaxInFlightGetRpcs() const;
97 int PartitionKey(uint64_t key) const;
98 std::vector<ShardChunk> BuildChunks(base::ConstArray<uint64_t> keys) const;
99 bool FinalizeBatchIfNeeded(BatchRequest* batch);
100 void
101 WaitShardRpcsCooperatively(const std::vector<PendingShardRpc>& shard_rpcs);
102 int SubmitGetParameter(base::ConstArray<uint64_t> keys,
103 float* values,
104 bool isAsync,
105 int async_req_id);
106 bool QueryRPCFinished(int rpc_id);
107 void WaitRPCFinish(int rpc_id);
108 void RevokeRPCResource(int rpc_id);
109 float* AcquirePrefetchBuffer(std::size_t bytes, std::size_t* buffer_id);
110 void ReleasePrefetchBuffer(std::size_t buffer_id);
111 const float* BorrowPrefetchResult(const PrefetchState& state,
112 std::int32_t* status_code,
113 std::size_t* response_bytes);
114 PrefetchState GetPrefetchState(uint64_t prefetch_id);
115 void MarkPrefetchConsumed(uint64_t prefetch_id);
116
117 json config_;
118 std::mutex init_mu_;
119 std::mutex thread_init_mu_;
120 std::mutex state_mu_;
121 bool initialized_ = false;
122 std::unordered_set<std::thread::id> initialized_threads_;
123 std::vector<std::unique_ptr<petps::PetPSClient>> shard_clients_;
124 BaseParameterClient* client_ = nullptr;
125 int num_shards_ = 1;
126 std::string hash_method_ = "city_hash";
127 std::unordered_map<int, int> shard_to_client_index_;
128 int batch_rpc_id_acc_ = -1;
129 mutable std::mutex batches_mu_;
130 std::unordered_map<int, BatchRequest> batches_;
131 std::unordered_map<std::string, TableState> tables_;
132 std::vector<std::unique_ptr<char[]>> prefetch_buffers_;
133 std::vector<std::size_t> prefetch_buffer_capacities_;
134 std::vector<std::size_t> free_prefetch_buffer_ids_;
135 std::unordered_map<uint64_t, PrefetchState> prefetches_;
136 uint64_t next_prefetch_id_ = 1;
137 };
138
139 } // namespace recstore
140