Skip to content

Commit b194b32

Browse files
datenglin and copybara-github
authored and committed
Implemented the weight sync for torch tensors.
PiperOrigin-RevId: 926213635
1 parent 1c125d5 commit b194b32

4 files changed

Lines changed: 364 additions & 12 deletions

File tree

api/torch/weight_synchronizer.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,68 @@ def pull_weights(self, source: str) -> None:
5252
"""Inference server pulling current weights from the source peer coordinate E2E."""
5353
self._impl.PullWeights(source)
5454

55+
def d2h(self) -> None:
  """Copies the current weights from device to the host buffer (D2H).

  Although the underlying transfer is issued as a future, the native `D2h`
  binding waits on that future before returning, so this call blocks until
  the copy has completed. The binding releases the GIL while it runs.

  Raises:
    RuntimeError: If the native call cannot start the copy, or the copy
      itself finishes with a non-OK status.
  """
  self._impl.D2h()
58+
59+
def pull_weights_chunk(
    self,
    source: str,
    src_shard_idx: int,
    src_offset_bytes: int,
    dst_shard_idx: int,
    dst_offset_bytes: int,
    size_bytes: int,
) -> None:
  """Fetches one byte range from a source worker peer into a local shard.

  Used on the inference-server side to pull part of a shard straight from a
  source peer instead of transferring the full set of weights.

  Args:
    source: Address of the source peer, formatted as "host:port".
    src_shard_idx: Index of the device shard to read on the source peer.
    src_offset_bytes: Byte offset into the source shard's staging buffer at
      which reading starts.
    dst_shard_idx: Index of the local device shard that receives the data.
    dst_offset_bytes: Byte offset into the local destination staging buffer
      at which writing starts.
    size_bytes: Length of the range to transfer, in bytes.

  Raises:
    RuntimeError: If the native `PullWeightsChunk` call returns a non-OK
      status.
  """
  # Forwarded positionally, in the order the native binding declares them.
  chunk_request = (
      source,
      src_shard_idx,
      src_offset_bytes,
      dst_shard_idx,
      dst_offset_bytes,
      size_bytes,
  )
  self._impl.PullWeightsChunk(*chunk_request)
86+
87+
def h2d_chunk(
    self,
    shard_idx: int,
    host_offset_bytes: int,
    device_offset_bytes: int,
    size_bytes: int,
) -> None:
  """Copies one chunk from the host staging buffer to device memory (H2D).

  Although the underlying transfer is issued as a future, the native
  `H2dChunk` binding waits on that future before returning, so this call
  blocks until the chunk has been copied. The binding releases the GIL while
  it runs.

  Args:
    shard_idx: Target shard index.
    host_offset_bytes: Source offset, in bytes, in the host staging buffer.
    device_offset_bytes: Destination offset, in bytes, in device memory.
    size_bytes: Number of bytes to copy.

  Raises:
    RuntimeError: If the native call cannot start the copy, or the copy
      itself finishes with a non-OK status.
  """
  self._impl.H2dChunk(
      shard_idx, host_offset_bytes, device_offset_bytes, size_bytes
  )
105+
106+
def get_host_buffer(
    self, layer_idx: int = 0, shard_idx: int = 0
) -> torch.Tensor:
  """Exposes a C++ host staging buffer as a CPU tensor without copying it.

  Args:
    layer_idx: Index of the layer whose buffer is requested.
    shard_idx: Index of the shard whose buffer is requested.

  Returns:
    A one-dimensional uint8 tensor that aliases the staging buffer; the
    native binding sizes it as the slice byte size plus 256 KiB.

  Raises:
    RuntimeError: If the native layer has no buffer for the given layer and
      shard indices.
  """
  host_view = self._impl.get_host_buffer(layer_idx, shard_idx)
  return host_view
116+
55117
@property
56118
def local_port(self) -> Optional[int]:
57119
"""Returns assigned ephemeral listener port coordinates."""

frameworks/torch/weight_synchronizer_module.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,74 @@ NB_MODULE(_weight_synchronizer, m) {
6060
}
6161
},
6262
nb::arg("source"), nb::call_guard<nb::gil_scoped_release>())
63+
.def(
64+
"D2h",
65+
[](WeightSynchronizer& self) {
66+
auto status_or_future = self.D2h();
67+
if (!status_or_future.ok()) {
68+
throw std::runtime_error(
69+
"WeightSynchronizer D2H failed: " +
70+
std::string(status_or_future.status().message()));
71+
}
72+
absl::Status status = status_or_future.value().Await().status();
73+
if (!status.ok()) {
74+
throw std::runtime_error("WeightSynchronizer D2H copy failed: " +
75+
std::string(status.message()));
76+
}
77+
},
78+
nb::call_guard<nb::gil_scoped_release>())
79+
.def(
80+
"H2dChunk",
81+
[](WeightSynchronizer& self, size_t shard_idx,
82+
size_t host_offset_bytes, size_t device_offset_bytes,
83+
size_t size_bytes) {
84+
auto status_or_future = self.H2dChunk(
85+
shard_idx, host_offset_bytes, device_offset_bytes, size_bytes);
86+
if (!status_or_future.ok()) {
87+
throw std::runtime_error(
88+
"WeightSynchronizer H2dChunk failed: " +
89+
std::string(status_or_future.status().message()));
90+
}
91+
absl::Status status = status_or_future.value().Await().status();
92+
if (!status.ok()) {
93+
throw std::runtime_error(
94+
"WeightSynchronizer H2dChunk copy failed: " +
95+
std::string(status.message()));
96+
}
97+
},
98+
nb::arg("shard_idx"), nb::arg("host_offset_bytes"),
99+
nb::arg("device_offset_bytes"), nb::arg("size_bytes"),
100+
nb::call_guard<nb::gil_scoped_release>())
101+
.def(
102+
"PullWeightsChunk",
103+
[](WeightSynchronizer& self, const std::string& source,
104+
size_t src_shard_idx, size_t src_offset_bytes,
105+
size_t dst_shard_idx, size_t dst_offset_bytes, size_t size_bytes) {
106+
absl::Status s = self.PullWeightsChunk(
107+
source, src_shard_idx, src_offset_bytes, dst_shard_idx,
108+
dst_offset_bytes, size_bytes);
109+
if (!s.ok()) {
110+
throw std::runtime_error(
111+
"WeightSynchronizer PullWeightsChunk failed: " +
112+
std::string(s.message()));
113+
}
114+
},
115+
nb::arg("source"), nb::arg("src_shard_idx"),
116+
nb::arg("src_offset_bytes"), nb::arg("dst_shard_idx"),
117+
nb::arg("dst_offset_bytes"), nb::arg("size_bytes"),
118+
nb::call_guard<nb::gil_scoped_release>())
119+
.def(
120+
"get_host_buffer",
121+
[](WeightSynchronizer& self, size_t layer_idx, size_t shard_idx) {
122+
const uint8_t* ptr = self.GetHostBufferPtr(layer_idx, shard_idx);
123+
if (!ptr) {
124+
throw std::runtime_error("Invalid layer or shard index");
125+
}
126+
size_t size = self.slice_byte_size() + 256 * 1024;
127+
return at::from_blob(const_cast<uint8_t*>(ptr),
128+
{static_cast<int64_t>(size)}, at::kByte);
129+
},
130+
nb::arg("layer_idx") = 0, nb::arg("shard_idx") = 0)
63131
.def_prop_ro("local_port", &WeightSynchronizer::local_port)
64132
.def_prop_ro("num_layers", &WeightSynchronizer::num_layers)
65133
.def_prop_ro("num_shards", &WeightSynchronizer::num_shards)

weight_sync/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ cc_library(
3939
"@com_google_absl//absl/status",
4040
"@com_google_absl//absl/status:status_macros",
4141
"@com_google_absl//absl/status:statusor",
42+
"@com_google_absl//absl/types:span",
43+
"@xla//xla/pjrt:pjrt_client",
4244
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
4345
"@xla//xla/pjrt/c:pjrt_c_api_raw_buffer_extension_hdrs",
4446
"@xla//xla/tsl/platform:errors",

0 commit comments

Comments
 (0)