Commit 63a03ee

[feature]2.2 custom_allreduce support cudagraph recapture (#4307)
* custom_allreduce support cudagraph recapture
* delete code
* add shut_down/restart default group
1 parent 9cc2c99 commit 63a03ee
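
Taken together, the seven diffs below let a worker drop its captured CUDA graphs, including the custom all-reduce buffers that were registered inside them, and capture again later without rebuilding the communicator. A minimal sketch of that cycle, using only names introduced in this commit; the backend object and the trigger for the next capture are placeholders, not code from the repository:

def recapture_after_weight_update(backend) -> None:
    # `backend` stands for the piecewise CUDA-graph backend defined in
    # cudagraph_piecewise_backend.py (an assumption about the caller, not
    # something this commit adds).
    backend.clear_graph()  # with this commit, also closes custom-allreduce IPC handles
    # Weights are then reloaded and process groups restarted (see
    # fastdeploy/rl/dynamic_weight_manager.py); the next warmup capture can
    # register fresh graph buffers because the stale handles are gone.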

File tree

7 files changed (+31, -3 lines)


custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 4 additions & 0 deletions

@@ -616,6 +616,8 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);

 void free_shared_buffer(int64_t buffer);

+void clear_ipc_handles(int64_t _fa);
+
 // speculative decoding Kernel
 std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,

@@ -1204,6 +1206,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

   m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");

+  m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");
+
   m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");

   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");

custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu

Lines changed: 5 additions & 1 deletion

@@ -122,10 +122,14 @@ void register_graph_buffers(fptr_t _fa,
   for (int i = 0; i < handles.size(); i++) {
     bytes.emplace_back(handles[i].begin(), handles[i].end());
   }
-  bytes.reserve(handles.size());
   fa->register_graph_buffers(bytes, offsets);
 }

+void clear_ipc_handles(fptr_t _fa) {
+  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+  fa->clear_ipc_handles();
+}
+
 std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
     int64_t size) {

custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh

Lines changed: 6 additions & 1 deletion

@@ -517,10 +517,15 @@ class CustomAllreduce {
 #undef KL
   }

-  ~CustomAllreduce() {
+  void clear_ipc_handles(){
     for (auto [_, ptr] : ipc_handles_) {
       CUDACHECK(cudaIpcCloseMemHandle(ptr));
     }
+    ipc_handles_.clear();
+  }
+
+  ~CustomAllreduce() {
+    clear_ipc_handles();
   }
 };
 } // namespace paddle

fastdeploy/distributed/communication.py

Lines changed: 6 additions & 0 deletions

@@ -42,6 +42,12 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
     _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)


+def custom_ar_clear_ipc_handles():
+    global _TP_AR
+    if _TP_AR is not None:
+        _TP_AR.clear_ipc_handles()
+
+
 try:

     @paddle.jit.marker.unified
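
custom_ar_clear_ipc_handles() is intentionally a module-level, null-safe entry point: callers do not have to know whether custom all-reduce was ever enabled on this rank. A short usage sketch, mirroring the guard shown in the hunk:

from fastdeploy.distributed.communication import custom_ar_clear_ipc_handles

# Safe even if use_custom_allreduce() was never called on this process:
# the helper checks the module-global _TP_AR and returns when it is None.
custom_ar_clear_ipc_handles()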

fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py

Lines changed: 4 additions & 0 deletions

@@ -25,6 +25,7 @@
 from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
 from fastdeploy.model_executor.ops.gpu import (
     all_reduce,
+    clear_ipc_handles,
     dispose,
     get_graph_buffer_ipc_meta,
     init_custom_all_reduce,

@@ -220,6 +221,9 @@ def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
         else:
             return self.all_reduce(input, input, registered=False)

+    def clear_ipc_handles(self):
+        clear_ipc_handles(self._ptr)
+
     def close(self):
         if self._ptr:
             dispose(self._ptr)
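
The new method is what separates recapture from shutdown: clear_ipc_handles() only closes the peer cudaIpc mappings held by the underlying C++ object, while close() disposes of the object itself. A hedged sketch of the difference, assuming `ar` is an already-initialized wrapper instance:

# `ar` is a fastdeploy.distributed.custom_all_reduce.CustomAllreduce that has
# already been used for at least one CUDA-graph capture (assumed setup).
ar.clear_ipc_handles()  # close stale cudaIpc handles; ar._ptr stays valid, so a
                        # later capture can register graph buffers again
ar.close()              # by contrast, dispose(ar._ptr) tears the communicator down for good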

fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Lines changed: 5 additions & 1 deletion

@@ -23,7 +23,10 @@
 from paddle.device.cuda import graphs

 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication import capture_custom_allreduce
+from fastdeploy.distributed.communication import (
+    capture_custom_allreduce,
+    custom_ar_clear_ipc_handles,
+)
 from fastdeploy.utils import get_logger

 logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")

@@ -208,6 +211,7 @@ def _create_entry_dict(self):
     def clear_graph(self):
         """ """
         # Clear graphs
+        custom_ar_clear_ipc_handles()
         for id, entry in self.concrete_size_entries.items():
             if entry.cuda_graph:
                 del entry.cuda_graph
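
In clear_graph(), the IPC cleanup now runs before the captured graphs are released: the handles cached in ipc_handles_ were opened when the peer buffers (including the graph buffers) were registered for the previous capture, and dropping them here leaves the communicator ready for the next one. An annotated paraphrase of the method after this commit (not a verbatim copy; whatever follows the loop in the original method is elided):

def clear_graph(self):
    """ """
    # Clear graphs
    custom_ar_clear_ipc_handles()  # close custom-allreduce IPC handles from the previous capture
    for id, entry in self.concrete_size_entries.items():
        if entry.cuda_graph:
            del entry.cuda_graph   # release each captured CUDA graph
    ...  # remaining statements of the original method are not shown in this hunk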

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@ def update_parameters(self, pid: int = 0) -> None:
         paddle.device.cuda.empty_cache()

         if not self.first_load:
+            paddle.distributed.restart_process_group()
             paddle.distributed.restart_process_group(self.parallel_config.tp_group)
             if self.parallel_config.enable_expert_parallel:
                 paddle.distributed.restart_process_group(self.parallel_config.ep_group)
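
restart_process_group() with no argument restarts Paddle's default communication group, in addition to the TP/EP groups that were already being restarted; the commit message mentions a matching shut_down of the default group, which is presumably performed on the offload side of the update cycle and is not part of this diff. A hedged sketch of the restore path, where the shutdown counterpart is an assumption taken from the commit message rather than a call shown here:

import paddle.distributed as dist

def restart_groups(parallel_config) -> None:
    # Mirrors the `not self.first_load` branch of update_parameters(); the
    # corresponding shutdown calls on the default/TP/EP groups are assumed to
    # happen earlier, when weights are cleared, and are not shown in this commit.
    dist.restart_process_group()                              # default group (new in this commit)
    dist.restart_process_group(parallel_config.tp_group)      # tensor-parallel group
    if parallel_config.enable_expert_parallel:
        dist.restart_process_group(parallel_config.ep_group)  # expert-parallel group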
