Skip to content

Commit 0d2614d

Browse files
authored
Fix errors in hivemind.p2p and hivemind.compression (#565)
This PR: 1. Fixes warnings in hivemind.p2p destructors. 2. Makes bfloat16 serialization in hivemind.compression forward- and backward-compatible. The code before this PR (a) didn't work in torch < 1.13.0 (hivemind requires torch >= 1.9.0) and (b) led to warnings on torch >= 2.0. The new code works without warnings in all versions of PyTorch.
1 parent 6c3a46c commit 0d2614d

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

hivemind/compression/base.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,11 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
8787
dtype_name = str(tensor.dtype).lstrip("torch.")
8888
raw_data = tensor
8989
if tensor.dtype == torch.bfloat16:
90-
if USE_LEGACY_BFLOAT16:
90+
if USE_LEGACY_BFLOAT16: # legacy mode: convert to fp32
9191
raw_data = tensor.to(torch.float32)
92-
else:
93-
typed_storage = tensor.storage()
94-
storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
95-
raw_data = torch.tensor(storage, dtype=torch.int8)
92+
else: # efficient mode: send bfloat16 data directly
93+
# reinterpret_cast to an arbitrary 2-byte type supported by numpy
94+
raw_data = tensor.view(torch.int16)
9695

9796
return runtime_pb2.Tensor(
9897
compression=self.compression_type,
@@ -106,13 +105,13 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
106105
shape = torch.Size(serialized_tensor.size)
107106
if serialized_tensor.dtype == "bfloat16":
108107
numel = shape.numel()
109-
if numel > 0 and len(serialized_tensor.buffer) // numel == 4: # legacy mode: convert to fp32
108+
if numel > 0 and len(serialized_tensor.buffer) // numel == 4:
110109
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
111110
tensor = torch.as_tensor(array, dtype=torch.bfloat16)
112-
else: # efficient mode: send bfloat16 data directly
113-
storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
114-
storage = storage_type.from_buffer(serialized_tensor.buffer, byte_order="little", dtype=torch.bfloat16)
115-
tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
111+
else:
112+
array = np.frombuffer(serialized_tensor.buffer, dtype=np.int16)
113+
# reinterpret_cast from an arbitrary 2-byte type supported by numpy
114+
tensor = torch.as_tensor(array).view(torch.bfloat16)
116115
else:
117116
array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
118117
tensor = torch.as_tensor(array)

hivemind/p2p/p2p_daemon.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,9 @@ def _terminate(self) -> None:
654654

655655
self._alive = False
656656
if self._child is not None and self._child.returncode is None:
657-
self._child.terminate()
658-
logger.debug(f"Terminated p2pd with id = {self.peer_id}")
657+
with suppress(ProcessLookupError):
658+
self._child.terminate()
659+
logger.debug(f"Terminated p2pd with id = {self.peer_id}")
659660

660661
with suppress(FileNotFoundError):
661662
os.remove(self._daemon_listen_maddr["unix"])

hivemind/p2p/p2p_daemon_bindings/p2pclient.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ async def create(
4747
return client
4848

4949
def close(self) -> None:
50-
self.control.close()
50+
if self.control is not None:
51+
self.control.close()
5152

5253
def __del__(self):
5354
self.close()

0 commit comments

Comments (0)