@@ -60,6 +60,74 @@ NB_MODULE(_weight_synchronizer, m) {
6060 }
6161 },
6262 nb::arg (" source" ), nb::call_guard<nb::gil_scoped_release>())
63+ .def (
64+ " D2h" ,
65+ [](WeightSynchronizer& self) {
66+ auto status_or_future = self.D2h ();
67+ if (!status_or_future.ok ()) {
68+ throw std::runtime_error (
69+ " WeightSynchronizer D2H failed: " +
70+ std::string (status_or_future.status ().message ()));
71+ }
72+ absl::Status status = status_or_future.value ().Await ().status ();
73+ if (!status.ok ()) {
74+ throw std::runtime_error (" WeightSynchronizer D2H copy failed: " +
75+ std::string (status.message ()));
76+ }
77+ },
78+ nb::call_guard<nb::gil_scoped_release>())
79+ .def (
80+ " H2dChunk" ,
81+ [](WeightSynchronizer& self, size_t shard_idx,
82+ size_t host_offset_bytes, size_t device_offset_bytes,
83+ size_t size_bytes) {
84+ auto status_or_future = self.H2dChunk (
85+ shard_idx, host_offset_bytes, device_offset_bytes, size_bytes);
86+ if (!status_or_future.ok ()) {
87+ throw std::runtime_error (
88+ " WeightSynchronizer H2dChunk failed: " +
89+ std::string (status_or_future.status ().message ()));
90+ }
91+ absl::Status status = status_or_future.value ().Await ().status ();
92+ if (!status.ok ()) {
93+ throw std::runtime_error (
94+ " WeightSynchronizer H2dChunk copy failed: " +
95+ std::string (status.message ()));
96+ }
97+ },
98+ nb::arg (" shard_idx" ), nb::arg (" host_offset_bytes" ),
99+ nb::arg (" device_offset_bytes" ), nb::arg (" size_bytes" ),
100+ nb::call_guard<nb::gil_scoped_release>())
101+ .def (
102+ " PullWeightsChunk" ,
103+ [](WeightSynchronizer& self, const std::string& source,
104+ size_t src_shard_idx, size_t src_offset_bytes,
105+ size_t dst_shard_idx, size_t dst_offset_bytes, size_t size_bytes) {
106+ absl::Status s = self.PullWeightsChunk (
107+ source, src_shard_idx, src_offset_bytes, dst_shard_idx,
108+ dst_offset_bytes, size_bytes);
109+ if (!s.ok ()) {
110+ throw std::runtime_error (
111+ " WeightSynchronizer PullWeightsChunk failed: " +
112+ std::string (s.message ()));
113+ }
114+ },
115+ nb::arg (" source" ), nb::arg (" src_shard_idx" ),
116+ nb::arg (" src_offset_bytes" ), nb::arg (" dst_shard_idx" ),
117+ nb::arg (" dst_offset_bytes" ), nb::arg (" size_bytes" ),
118+ nb::call_guard<nb::gil_scoped_release>())
119+ .def (
120+ " get_host_buffer" ,
121+ [](WeightSynchronizer& self, size_t layer_idx, size_t shard_idx) {
122+ const uint8_t * ptr = self.GetHostBufferPtr (layer_idx, shard_idx);
123+ if (!ptr) {
124+ throw std::runtime_error (" Invalid layer or shard index" );
125+ }
126+ size_t size = self.slice_byte_size () + 256 * 1024 ;
127+ return at::from_blob (const_cast <uint8_t *>(ptr),
128+ {static_cast <int64_t >(size)}, at::kByte );
129+ },
130+ nb::arg (" layer_idx" ) = 0 , nb::arg (" shard_idx" ) = 0 )
63131 .def_prop_ro (" local_port" , &WeightSynchronizer::local_port)
64132 .def_prop_ro (" num_layers" , &WeightSynchronizer::num_layers)
65133 .def_prop_ro (" num_shards" , &WeightSynchronizer::num_shards)
0 commit comments