
Commit 1bbbb1e

Merge pull request #745 from exo-explore/grpcio1.70.0
Grpcio1.70.0
2 parents: 52a2164 + 4081305 · commit 1bbbb1e

File tree

3 files changed: +24 −15 lines

exo/networking/grpc/grpc_peer_handle.py

Lines changed: 20 additions & 13 deletions
@@ -29,15 +29,16 @@ def __init__(self, _id: str, address: str, desc: str, device_capabilities: Devic
     self.channel = None
     self.stub = None
     self.channel_options = [
-      ("grpc.max_metadata_size", 64 * 1024 * 1024),
+      ("grpc.max_metadata_size", 32 * 1024 * 1024),
       ("grpc.max_receive_message_length", 256 * 1024 * 1024),
       ("grpc.max_send_message_length", 256 * 1024 * 1024),
       ("grpc.max_concurrent_streams", 100),
       ("grpc.http2.min_time_between_pings_ms", 10000),
-      ("grpc.keepalive_time_ms", 20000),
-      ("grpc.keepalive_timeout_ms", 10000),
+      ("grpc.keepalive_time_ms", 10000),
+      ("grpc.keepalive_timeout_ms", 5000),
       ("grpc.keepalive_permit_without_calls", 1),
       ("grpc.http2.max_pings_without_data", 0),
+      ("grpc.http2.min_ping_interval_without_data_ms", 5000),
       ("grpc.tcp_nodelay", 1),
       ("grpc.optimization_target", "throughput"),
     ]
@@ -55,14 +56,13 @@ def device_capabilities(self) -> DeviceCapabilities:
     return self._device_capabilities

   async def connect(self):
-    if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(
-        self.address,
-        options=self.channel_options,
-        compression=grpc.Compression.Gzip
-      )
-      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
-      await self.channel.channel_ready()
+    self.channel = grpc.aio.insecure_channel(
+      self.address,
+      options=self.channel_options,
+      compression=grpc.Compression.Gzip
+    )
+    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+    await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)

   async def is_connected(self) -> bool:
     return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
@@ -74,7 +74,7 @@ async def disconnect(self):
     self.stub = None

   async def _ensure_connected(self):
-    if not await self.is_connected():
+    if not (await self.is_connected()):
       try:
         await asyncio.wait_for(self.connect(), timeout=10.0)
       except asyncio.TimeoutError:
@@ -98,6 +98,7 @@ async def health_check(self) -> bool:
       return False

   async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -112,6 +113,7 @@ async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional
     await self.stub.SendPrompt(request)

   async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -131,6 +133,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

   async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.ExampleRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -153,6 +156,7 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr
     return loss

   async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -171,6 +175,7 @@ async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

   async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    await self._ensure_connected()
     request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
     response = await self.stub.CollectTopology(request)
     topology = Topology()
@@ -185,6 +190,7 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
     return topology

   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    await self._ensure_connected()
     tensor = None
     if isinstance(result, np.ndarray):
       tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
@@ -193,8 +199,9 @@ async def send_result(self, request_id: str, result: List[int], is_finished: boo
     await self.stub.SendResult(request)

   async def send_opaque_status(self, request_id: str, status: str) -> None:
+    await self._ensure_connected()
     request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
-    await self.stub.SendOpaqueStatus(request)
+    await asyncio.wait_for(self.stub.SendOpaqueStatus(request), timeout=10.0)

   def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
     proto_inference_state = node_service_pb2.InferenceState()
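The repeated "await self._ensure_connected()" calls added above all follow the same lazy-reconnect pattern: check the channel state before issuing an RPC and, if it is not READY, rebuild the channel under a timeout. Below is a minimal, self-contained sketch of that pattern; it is illustrative only (LazyPeer and call_unary are made-up names, not exo APIs) and reuses the keepalive options introduced in this commit.

import asyncio
import grpc

# Keepalive options taken from the client-side diff above.
CHANNEL_OPTIONS = [
  ("grpc.keepalive_time_ms", 10000),           # ping the peer every 10s
  ("grpc.keepalive_timeout_ms", 5000),         # treat it as dead if no ack within 5s
  ("grpc.keepalive_permit_without_calls", 1),  # keep pinging even when no RPC is active
]

class LazyPeer:
  """Illustrative stand-in for the peer handle's reconnect logic."""
  def __init__(self, address: str):
    self.address = address
    self.channel = None

  async def connect(self):
    # Mirrors the new connect(): always rebuild the channel, bounded by a timeout.
    self.channel = grpc.aio.insecure_channel(self.address, options=CHANNEL_OPTIONS)
    await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)

  async def is_connected(self) -> bool:
    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

  async def _ensure_connected(self):
    if not (await self.is_connected()):
      await asyncio.wait_for(self.connect(), timeout=10.0)

  async def call_unary(self, make_request):
    # Every RPC entry point guards itself, as send_prompt/send_tensor/... do above.
    await self._ensure_connected()
    return await make_request(self.channel)

Because connect() no longer short-circuits when self.channel is already set, a broken channel is simply replaced the next time any RPC runs rather than being reused.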

exo/networking/grpc/grpc_server.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ async def start(self) -> None:
         ("grpc.max_concurrent_streams", 100),
         ("grpc.tcp_nodelay", 1),
         ("grpc.optimization_target", "throughput"),
+        ("grpc.keepalive_permit_without_calls", 1),
+        ("grpc.http2.max_concurrent_streams", 0),  # Unlimited concurrent streams
       ],
     )
     node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
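On the server side, the two new options pair with the client keepalive settings: the server accepts keepalive pings while no RPC is in flight, and lifts the HTTP/2 cap on concurrent streams per connection (0 is documented as "unlimited" in the diff comment). A rough sketch of a grpc.aio server constructed with these options; the servicer registration is elided and host/port are placeholders, not exo defaults.

import asyncio
import grpc

async def serve(host: str = "0.0.0.0", port: int = 50051):
  server = grpc.aio.server(options=[
    ("grpc.keepalive_permit_without_calls", 1),  # accept pings from idle clients
    ("grpc.http2.max_concurrent_streams", 0),    # 0 = no per-connection stream cap
  ])
  # A real server registers its servicer here, e.g.
  # node_service_pb2_grpc.add_NodeServiceServicer_to_server(servicer, server)
  server.add_insecure_port(f"{host}:{port}")
  await server.start()
  await server.wait_for_termination()

if __name__ == "__main__":
  asyncio.run(serve())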

setup.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@
     "aiohttp==3.10.11",
     "aiohttp_cors==0.7.0",
     "aiofiles==24.1.0",
-    "grpcio==1.68.0",
-    "grpcio-tools==1.68.0",
+    "grpcio==1.70.0",
+    "grpcio-tools==1.70.0",
     "Jinja2==3.1.4",
     "numpy==2.0.0",
     "nuitka==2.5.1",
