Skip to content

Commit c51a6bf

Browse files
authored
Add an async interface to the AllGather implementation (#551)
Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Mads R. B. Kristensen (https://github.com/madsbk) URL: #551
1 parent 13d0448 commit c51a6bf

38 files changed

+607
-84
lines changed

cpp/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,12 @@ add_library(
168168
if(RAPIDSMPF_HAVE_STREAMING)
169169
target_sources(
170170
rapidsmpf
171-
PRIVATE src/streaming/core/context.cpp
171+
PRIVATE src/streaming/coll/allgather.cpp
172+
src/streaming/coll/shuffler.cpp
173+
src/streaming/core/context.cpp
172174
src/streaming/core/leaf_node.cpp
173175
src/streaming/core/node.cpp
174176
src/streaming/cudf/partition.cpp
175-
src/streaming/cudf/shuffler.cpp
176177
src/streaming/cudf/table_chunk.cpp
177178
)
178179
endif()

cpp/benchmarks/streaming/bench_streaming_shuffle.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
#include <rapidsmpf/nvtx.hpp>
1919
#include <rapidsmpf/shuffler/shuffler.hpp>
2020
#include <rapidsmpf/statistics.hpp>
21+
#include <rapidsmpf/streaming/coll/shuffler.hpp>
2122
#include <rapidsmpf/streaming/core/channel.hpp>
2223
#include <rapidsmpf/streaming/core/context.hpp>
2324
#include <rapidsmpf/streaming/core/node.hpp>
2425
#include <rapidsmpf/streaming/cudf/partition.hpp>
25-
#include <rapidsmpf/streaming/cudf/shuffler.hpp>
2626
#include <rapidsmpf/streaming/cudf/table_chunk.hpp>
2727
#include <rapidsmpf/utils.hpp>
2828

cpp/include/rapidsmpf/allgather/allgather.hpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <chrono>
99
#include <condition_variable>
1010
#include <cstdint>
11+
#include <functional>
1112
#include <limits>
1213
#include <memory>
1314
#include <mutex>
@@ -373,9 +374,10 @@ class AllGather {
373374
/**
374375
* @brief Insert packed data into the allgather operation.
375376
*
377+
* @param sequence_number Local ordered sequence number of the data.
376378
* @param packed_data The data to contribute to the allgather.
377379
*/
378-
void insert(PackedData&& packed_data);
380+
void insert(std::uint64_t sequence_number, PackedData&& packed_data);
379381

380382
/**
381383
* @brief Mark that this rank has finished contributing data.
@@ -445,6 +447,9 @@ class AllGather {
445447
* @param br Buffer resource for memory allocation.
446448
* @param statistics Statistics collection instance (disabled by
447449
* default).
450+
* @param finished_callback Optional callback run when the allgather is locally
451+
* finished. The callback is guaranteed to be called by the progress thread exactly
452+
* once when the allgather is locally ready.
448453
*
449454
* @note The caller promises that inserted buffers are stream-ordered with respect
450455
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
@@ -455,7 +460,8 @@ class AllGather {
455460
std::shared_ptr<ProgressThread> progress_thread,
456461
OpID op_id,
457462
BufferResource* br,
458-
std::shared_ptr<Statistics> statistics = Statistics::disabled()
463+
std::shared_ptr<Statistics> statistics = Statistics::disabled(),
464+
std::function<void(void)>&& finished_callback = nullptr
459465
);
460466

461467
/// @brief Deleted copy constructor.
@@ -524,8 +530,10 @@ class AllGather {
524530
progress_thread_; ///< Progress thread for async operations
525531
BufferResource* br_; ///< Buffer resource for memory allocation
526532
std::shared_ptr<Statistics> statistics_; ///< Statistics collection instance
533+
std::function<void(void)> finished_callback_{
534+
nullptr
535+
}; ///< Optional callback to run when allgather is finished and ready for extraction.
527536
std::atomic<Rank> finish_counter_; ///< Counter for finish markers received
528-
std::atomic<std::uint64_t> sequence_number_; ///< Sequence number for chunks
529537
std::atomic<std::uint32_t> nlocal_insertions_; ///< Number of local data insertions
530538
OpID op_id_; ///< Unique operation identifier
531539
std::atomic<bool> locally_finished_{false}; ///< Whether this rank has finished

cpp/include/rapidsmpf/buffer/packed_data.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <memory>
99
#include <vector>
1010

11+
#include <rmm/cuda_stream_view.hpp>
1112
#include <rmm/device_buffer.hpp>
1213

1314
#include <rapidsmpf/buffer/buffer.hpp>
@@ -68,6 +69,15 @@ struct PackedData {
6869
[[nodiscard]] bool empty() const {
6970
return metadata->empty() && data->size == 0;
7071
}
72+
73+
/**
74+
* @brief Get the stream associated with the data buffer.
75+
*
76+
* @return The CUDA stream.
77+
*/
78+
[[nodiscard]] rmm::cuda_stream_view stream() const {
79+
return data->stream();
80+
}
7181
};
7282

7383
} // namespace rapidsmpf
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#pragma once
7+
8+
#include <cstdint>
9+
10+
#include <rapidsmpf/buffer/packed_data.hpp>
11+
12+
namespace rapidsmpf::streaming {
13+
14+
/**
15+
* @brief Chunk of `PackedData` with sequence number.
16+
*/
17+
struct PackedDataChunk {
18+
/**
19+
* @brief Sequence number used to preserve chunk ordering.
20+
*/
21+
std::uint64_t sequence_number;
22+
23+
/**
24+
* @brief Packed data payload.
25+
*/
26+
PackedData data;
27+
};
28+
29+
} // namespace rapidsmpf::streaming
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#pragma once
7+
8+
#include <cstdint>
9+
#include <unordered_map>
10+
#include <vector>
11+
12+
#include <rapidsmpf/buffer/packed_data.hpp>
13+
#include <rapidsmpf/shuffler/chunk.hpp>
14+
15+
namespace rapidsmpf::streaming {
16+
17+
/**
18+
* @brief Chunk of packed partitions identified by partition ID.
19+
*
20+
* Represents a single unit of work in a streaming pipeline where each partition
21+
* is associated with a `PartID` and contains packed (serialized) data.
22+
*/
23+
struct PartitionMapChunk {
24+
/**
25+
* @brief Sequence number used to preserve chunk ordering.
26+
*/
27+
std::uint64_t sequence_number;
28+
29+
/**
30+
* @brief Packed data for each partition, keyed by partition ID.
31+
*/
32+
std::unordered_map<shuffler::PartID, PackedData> data;
33+
};
34+
35+
/**
36+
* @brief Chunk of packed partitions stored as a vector.
37+
*
38+
* Represents a single unit of work in a streaming pipeline where the partitions
39+
* are stored in a vector.
40+
*/
41+
struct PartitionVectorChunk {
42+
/**
43+
* @brief Sequence number used to preserve chunk ordering.
44+
*/
45+
std::uint64_t sequence_number;
46+
47+
/**
48+
* @brief Packed data for each partition stored in a vector.
49+
*/
50+
std::vector<PackedData> data;
51+
};
52+
53+
} // namespace rapidsmpf::streaming
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#pragma once
7+
8+
#include <memory>
9+
#include <vector>
10+
11+
#include <rapidsmpf/allgather/allgather.hpp>
12+
#include <rapidsmpf/buffer/packed_data.hpp>
13+
#include <rapidsmpf/communicator/communicator.hpp>
14+
#include <rapidsmpf/streaming/chunks/packed_data.hpp>
15+
#include <rapidsmpf/streaming/core/channel.hpp>
16+
#include <rapidsmpf/streaming/core/context.hpp>
17+
18+
#include <coro/event.hpp>
19+
#include <coro/task.hpp>
20+
21+
namespace rapidsmpf::streaming {
22+
23+
/**
24+
* @brief Asynchronous (coroutine) interface to `allgather::AllGather`.
25+
*
26+
* Once the AllGather is created, many tasks may insert data into it. If multiple tasks
27+
* insert data, the user is responsible for arranging that `insert_finished` is only
28+
* called after all `insert` calls have completed. A single consumer task should extract
29+
* data.
30+
*/
31+
class AllGather {
32+
public:
33+
/// @copydoc allgather::AllGather::Ordered
34+
using Ordered = rapidsmpf::allgather::AllGather::Ordered;
35+
/**
36+
* @brief Construct an asynchronous allgather.
37+
*
38+
* @param ctx Streaming context
39+
* @param op_id Unique identifier for the allgather.
40+
*/
41+
AllGather(std::shared_ptr<Context> ctx, OpID op_id);
42+
43+
AllGather(AllGather const&) = delete;
44+
AllGather& operator=(AllGather const&) = delete;
45+
AllGather(AllGather&&) = delete;
46+
AllGather& operator=(AllGather&&) = delete;
47+
48+
~AllGather();
49+
50+
/**
51+
* @brief Gets the streaming context associated with this AllGather object.
52+
*
53+
* @return Shared pointer to context.
54+
*/
55+
[[nodiscard]] std::shared_ptr<Context> ctx() const noexcept;
56+
57+
/**
58+
* @brief Insert a chunk into the allgather.
59+
*
60+
* @param chunk The chunk to insert holding data and a sequence number.
61+
*/
62+
void insert(PackedDataChunk&& chunk);
63+
64+
/// @copydoc rapidsmpf::allgather::AllGather::insert_finished()
65+
void insert_finished();
66+
67+
/**
68+
* @brief Extract all gathered data.
69+
*
70+
* @param ordered Whether the extracted data should be ordered. If ordered, the returned data
71+
* will be ordered first by rank and then by sequence number of the inserted chunks on
72+
* that rank.
73+
*
74+
* @return Coroutine that completes when all data is available for extraction and
75+
* returns the data.
76+
*/
77+
coro::task<std::vector<PackedDataChunk>> extract_all(Ordered ordered = Ordered::YES);
78+
79+
private:
80+
coro::event
81+
event_{}; ///< Event tracking whether all data has arrived and can be extracted.
82+
std::shared_ptr<Context> ctx_; ///< Streaming context.
83+
allgather::AllGather gatherer_; ///< Underlying collective allgather.
84+
};
85+
86+
namespace node {
87+
88+
/**
89+
* @brief Create an allgather node for a single allgather operation.
90+
*
91+
* This is a streaming version of `rapidsmpf::allgather::AllGather` that operates on
92+
* packed data received through `Channel`s.
93+
*
94+
* @param ctx The streaming context to use.
95+
* @param ch_in Input channel providing `PackedDataChunk`s to be gathered.
96+
* @param ch_out Output channel where the gathered `PackedDataChunk`s are sent.
97+
* @param op_id Unique identifier for the operation.
98+
* @param ordered If the extracted data should be sent to the output channel with sequence
99+
* numbers corresponding to the global total order of input chunks. If yes, then the
100+
* sequence numbers of the extracted data will be ordered first by rank and then by input
101+
* sequence number. If no, the sequence number of the extracted chunks will have no
102+
* relation to any input sequence order.
103+
*
104+
* @return A streaming node that completes when the allgather is finished and the output
105+
* channel is drained.
106+
*/
107+
Node allgather(
108+
std::shared_ptr<Context> ctx,
109+
std::shared_ptr<Channel> ch_in,
110+
std::shared_ptr<Channel> ch_out,
111+
OpID op_id,
112+
AllGather::Ordered ordered = AllGather::Ordered::YES
113+
);
114+
} // namespace node
115+
} // namespace rapidsmpf::streaming

cpp/include/rapidsmpf/streaming/cudf/partition.hpp

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,58 +5,17 @@
55

66
#pragma once
77

8-
#include <unordered_map>
8+
#include <cstdint>
99
#include <vector>
1010

1111
#include <cudf/partitioning.hpp>
12-
#include <cudf/table/table.hpp>
1312

14-
#include <rapidsmpf/buffer/packed_data.hpp>
15-
#include <rapidsmpf/shuffler/shuffler.hpp>
16-
#include <rapidsmpf/statistics.hpp>
1713
#include <rapidsmpf/streaming/core/channel.hpp>
1814
#include <rapidsmpf/streaming/core/context.hpp>
1915
#include <rapidsmpf/streaming/core/node.hpp>
20-
#include <rapidsmpf/streaming/cudf/table_chunk.hpp>
2116

2217
namespace rapidsmpf::streaming {
2318

24-
/**
25-
* @brief Chunk of packed partitions identified by partition ID.
26-
*
27-
* Represents a single unit of work in a streaming pipeline where each partition
28-
* is associated with a `PartID` and contains packed (serialized) data.
29-
*/
30-
struct PartitionMapChunk {
31-
/**
32-
* @brief Sequence number used to preserve chunk ordering.
33-
*/
34-
std::uint64_t sequence_number;
35-
36-
/**
37-
* @brief Packed data for each partition, keyed by partition ID.
38-
*/
39-
std::unordered_map<shuffler::PartID, PackedData> data;
40-
};
41-
42-
/**
43-
* @brief Chunk of packed partitions stored as a vector.
44-
*
45-
* Represents a single unit of work in a streaming pipeline where the partitions
46-
* are stored in a vector.
47-
*/
48-
struct PartitionVectorChunk {
49-
/**
50-
* @brief Sequence number used to preserve chunk ordering.
51-
*/
52-
std::uint64_t sequence_number;
53-
54-
/**
55-
* @brief Packed data for each partition stored in a vector.
56-
*/
57-
std::vector<PackedData> data;
58-
};
59-
6019
namespace node {
6120

6221
/**

0 commit comments

Comments
 (0)