@@ -29,15 +29,16 @@ def __init__(self, _id: str, address: str, desc: str, device_capabilities: Devic
2929 self .channel = None
3030 self .stub = None
3131 self .channel_options = [
32- ("grpc.max_metadata_size" , 64 * 1024 * 1024 ),
32+ ("grpc.max_metadata_size" , 32 * 1024 * 1024 ),
3333 ("grpc.max_receive_message_length" , 256 * 1024 * 1024 ),
3434 ("grpc.max_send_message_length" , 256 * 1024 * 1024 ),
3535 ("grpc.max_concurrent_streams" , 100 ),
3636 ("grpc.http2.min_time_between_pings_ms" , 10000 ),
37- ("grpc.keepalive_time_ms" , 20000 ),
38- ("grpc.keepalive_timeout_ms" , 10000 ),
37+ ("grpc.keepalive_time_ms" , 10000 ),
38+ ("grpc.keepalive_timeout_ms" , 5000 ),
3939 ("grpc.keepalive_permit_without_calls" , 1 ),
4040 ("grpc.http2.max_pings_without_data" , 0 ),
41+ ("grpc.http2.min_ping_interval_without_data_ms" , 5000 ),
4142 ("grpc.tcp_nodelay" , 1 ),
4243 ("grpc.optimization_target" , "throughput" ),
4344 ]
@@ -55,14 +56,13 @@ def device_capabilities(self) -> DeviceCapabilities:
5556 return self ._device_capabilities
5657
5758 async def connect (self ):
58- if self .channel is None :
59- self .channel = grpc .aio .insecure_channel (
60- self .address ,
61- options = self .channel_options ,
62- compression = grpc .Compression .Gzip
63- )
64- self .stub = node_service_pb2_grpc .NodeServiceStub (self .channel )
65- await self .channel .channel_ready ()
59+ self .channel = grpc .aio .insecure_channel (
60+ self .address ,
61+ options = self .channel_options ,
62+ compression = grpc .Compression .Gzip
63+ )
64+ self .stub = node_service_pb2_grpc .NodeServiceStub (self .channel )
65+ await asyncio .wait_for (self .channel .channel_ready (), timeout = 10.0 )
6666
6767 async def is_connected (self ) -> bool :
6868 return self .channel is not None and self .channel .get_state () == grpc .ChannelConnectivity .READY
@@ -74,7 +74,7 @@ async def disconnect(self):
7474 self .stub = None
7575
7676 async def _ensure_connected (self ):
77- if not await self .is_connected ():
77+ if not ( await self .is_connected () ):
7878 try :
7979 await asyncio .wait_for (self .connect (), timeout = 10.0 )
8080 except asyncio .TimeoutError :
@@ -98,6 +98,7 @@ async def health_check(self) -> bool:
9898 return False
9999
100100 async def send_prompt (self , shard : Shard , prompt : str , inference_state : Optional [dict ] = None , request_id : Optional [str ] = None ) -> Optional [np .array ]:
101+ await self ._ensure_connected ()
101102 request = node_service_pb2 .PromptRequest (
102103 prompt = prompt ,
103104 shard = node_service_pb2 .Shard (
@@ -112,6 +113,7 @@ async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional
112113 await self .stub .SendPrompt (request )
113114
114115 async def send_tensor (self , shard : Shard , tensor : np .ndarray , inference_state : Optional [dict ] = None , request_id : Optional [str ] = None ) -> Optional [np .array ]:
116+ await self ._ensure_connected ()
115117 request = node_service_pb2 .TensorRequest (
116118 shard = node_service_pb2 .Shard (
117119 model_id = shard .model_id ,
@@ -131,6 +133,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
131133 return np .frombuffer (response .tensor_data , dtype = np .dtype (response .dtype )).reshape (response .shape )
132134
133135 async def send_example (self , shard : Shard , example : np .ndarray , target : np .ndarray , length : np .ndarray , train : bool , request_id : Optional [str ] = None ) -> Optional [np .array ]:
136+ await self ._ensure_connected ()
134137 request = node_service_pb2 .ExampleRequest (
135138 shard = node_service_pb2 .Shard (
136139 model_id = shard .model_id ,
@@ -153,6 +156,7 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr
153156 return loss
154157
155158 async def send_loss (self , shard : Shard , tensor : np .ndarray , request_id : Optional [str ] = None ) -> Optional [np .array ]:
159+ await self ._ensure_connected ()
156160 request = node_service_pb2 .TensorRequest (
157161 shard = node_service_pb2 .Shard (
158162 model_id = shard .model_id ,
@@ -171,6 +175,7 @@ async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional
171175 return np .frombuffer (response .tensor_data , dtype = np .dtype (response .dtype )).reshape (response .shape )
172176
173177 async def collect_topology (self , visited : set [str ], max_depth : int ) -> Topology :
178+ await self ._ensure_connected ()
174179 request = node_service_pb2 .CollectTopologyRequest (visited = visited , max_depth = max_depth )
175180 response = await self .stub .CollectTopology (request )
176181 topology = Topology ()
@@ -185,6 +190,7 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
185190 return topology
186191
187192 async def send_result (self , request_id : str , result : List [int ], is_finished : bool ) -> None :
193+ await self ._ensure_connected ()
188194 tensor = None
189195 if isinstance (result , np .ndarray ):
190196 tensor = node_service_pb2 .Tensor (tensor_data = result .tobytes (), shape = result .shape , dtype = str (result .dtype ))
@@ -193,8 +199,9 @@ async def send_result(self, request_id: str, result: List[int], is_finished: boo
193199 await self .stub .SendResult (request )
194200
195201 async def send_opaque_status (self , request_id : str , status : str ) -> None :
202+ await self ._ensure_connected ()
196203 request = node_service_pb2 .SendOpaqueStatusRequest (request_id = request_id , status = status )
197- await self .stub .SendOpaqueStatus (request )
204+ await asyncio . wait_for ( self .stub .SendOpaqueStatus (request ), timeout = 10.0 )
198205
199206 def serialize_inference_state (self , inference_state : dict ) -> node_service_pb2 .InferenceState :
200207 proto_inference_state = node_service_pb2 .InferenceState ()
0 commit comments