@@ -30,14 +30,8 @@ def decorator(old_func):
30
30
VALUE_FUTURE = concurrent .futures .Future ()
31
31
32
32
33
- def _stub_construct_rpc_agent_options_handler (
34
- ** kwargs
35
- ):
36
- return mock .Mock () # RpcAgentOptions.
37
-
38
-
39
- def _stub_start_rpc_backend_handler (
40
- store , name , rank , world_size , rpc_agent_options
33
+ def stub_start_rpc_backend_handler (
34
+ store , self_name , self_rank , worker_name_to_id , * args , ** kwargs
41
35
):
42
36
return mock .Mock () # RpcAgent.
43
37
@@ -288,27 +282,22 @@ def test_register_rpc_backend_and_start_rpc_backend(
288
282
backend_name = "stub_backend"
289
283
290
284
backend = rpc .backend_registry .register_backend (
291
- backend_name ,
292
- _stub_construct_rpc_agent_options_handler ,
293
- _stub_start_rpc_backend_handler ,
285
+ backend_name , stub_start_rpc_backend_handler
294
286
)
295
287
296
288
with self .assertRaisesRegex (
297
289
RuntimeError , "^RPC backend .+: already registered$"
298
290
):
299
- backend = rpc .backend_registry .register_backend (
300
- backend_name ,
301
- _stub_construct_rpc_agent_options_handler ,
302
- _stub_start_rpc_backend_handler ,
291
+ rpc .backend_registry .register_backend (
292
+ backend_name , stub_start_rpc_backend_handler
303
293
)
304
294
305
295
rpc .init_rpc (
306
- name = "worker1" ,
296
+ self_name = "worker1" ,
307
297
backend = backend ,
308
298
init_method = self .init_method ,
309
- rank = self .rank ,
310
- world_size = self .world_size ,
311
- rpc_agent_options = self .rpc_agent_options ,
299
+ self_rank = self .rank ,
300
+ worker_name_to_id = self .worker_name_to_id ,
312
301
)
313
302
314
303
@requires_process_group_agent ("PROCESS_GROUP rpc backend specific test, skip" )
@@ -321,22 +310,20 @@ def test_duplicate_name(self):
321
310
rpc ._init_rpc_backend (
322
311
backend = self .rpc_backend ,
323
312
store = store ,
324
- name = "duplicate_name" ,
325
- rank = self .rank ,
326
- world_size = self .world_size ,
327
- rpc_agent_options = self .rpc_agent_options ,
313
+ self_name = "duplicate_name" ,
314
+ self_rank = self .rank ,
315
+ worker_name_to_id = self .worker_name_to_id ,
328
316
)
329
317
rpc .join_rpc ()
330
318
331
319
@dist_init (setup_rpc = False )
332
320
def test_reinit (self ):
333
321
rpc .init_rpc (
334
- name = "worker{}" .format (self .rank ),
322
+ self_name = "worker{}" .format (self .rank ),
335
323
backend = self .rpc_backend ,
336
324
init_method = self .init_method ,
337
- rank = self .rank ,
338
- world_size = self .world_size ,
339
- rpc_agent_options = self .rpc_agent_options ,
325
+ self_rank = self .rank ,
326
+ worker_name_to_id = self .worker_name_to_id ,
340
327
)
341
328
342
329
# This is for the below `dist.barrier`.
@@ -354,12 +341,11 @@ def test_reinit(self):
354
341
355
342
with self .assertRaisesRegex (RuntimeError , "is already initialized" ):
356
343
rpc .init_rpc (
357
- name = "worker{}" .format (self .rank ),
344
+ self_name = "worker{}" .format (self .rank ),
358
345
backend = self .rpc_backend ,
359
346
init_method = self .init_method ,
360
- rank = self .rank ,
361
- world_size = self .world_size ,
362
- rpc_agent_options = self .rpc_agent_options ,
347
+ self_rank = self .rank ,
348
+ worker_name_to_id = self .worker_name_to_id ,
363
349
)
364
350
rpc .join_rpc ()
365
351
@@ -372,10 +358,10 @@ def test_invalid_names(self):
372
358
rpc ._init_rpc_backend (
373
359
backend = self .rpc_backend ,
374
360
store = store ,
375
- name = "abc*" ,
376
- rank = self .rank ,
377
- world_size = self .world_size ,
378
- rpc_agent_options = self . rpc_agent_options ,
361
+ self_name = "abc*" ,
362
+ self_rank = self .rank ,
363
+ worker_name_to_id = self .worker_name_to_id ,
364
+ num_send_recv_threads = 16 ,
379
365
)
380
366
381
367
base_file_name = self .file_name
@@ -389,10 +375,10 @@ def test_invalid_names(self):
389
375
rpc ._init_rpc_backend (
390
376
backend = self .rpc_backend ,
391
377
store = store ,
392
- name = " " ,
393
- rank = self .rank ,
394
- world_size = self .world_size ,
395
- rpc_agent_options = self . rpc_agent_options ,
378
+ self_name = " " ,
379
+ self_rank = self .rank ,
380
+ worker_name_to_id = self .worker_name_to_id ,
381
+ num_send_recv_threads = 16 ,
396
382
)
397
383
398
384
# Use a different file path for FileStore to avoid rendezvous mismatch.
@@ -404,10 +390,10 @@ def test_invalid_names(self):
404
390
rpc ._init_rpc_backend (
405
391
backend = self .rpc_backend ,
406
392
store = store ,
407
- name = "" ,
408
- rank = self .rank ,
409
- world_size = self .world_size ,
410
- rpc_agent_options = self . rpc_agent_options ,
393
+ self_name = "" ,
394
+ self_rank = self .rank ,
395
+ worker_name_to_id = self .worker_name_to_id ,
396
+ num_send_recv_threads = 16 ,
411
397
)
412
398
413
399
# Use a different file path for FileStore to avoid rendezvous mismatch.
@@ -421,10 +407,10 @@ def test_invalid_names(self):
421
407
rpc ._init_rpc_backend (
422
408
backend = self .rpc_backend ,
423
409
store = store ,
424
- name = "" .join (["a" for i in range (500 )]),
425
- rank = self .rank ,
426
- world_size = self .world_size ,
427
- rpc_agent_options = self . rpc_agent_options ,
410
+ self_name = "" .join (["a" for i in range (500 )]),
411
+ self_rank = self .rank ,
412
+ worker_name_to_id = self .worker_name_to_id ,
413
+ num_send_recv_threads = 16 ,
428
414
)
429
415
430
416
from torch .distributed .rpc .api import _agent
@@ -510,12 +496,11 @@ def test_multi_rpc(self):
510
496
def test_join_rpc (self ):
511
497
# Initialize RPC.
512
498
rpc .init_rpc (
513
- name = "worker%d" % self .rank ,
499
+ self_name = "worker%d" % self .rank ,
514
500
backend = self .rpc_backend ,
515
501
init_method = self .init_method ,
516
- rank = self .rank ,
517
- world_size = self .world_size ,
518
- rpc_agent_options = self .rpc_agent_options ,
502
+ self_rank = self .rank ,
503
+ worker_name_to_id = self .worker_name_to_id ,
519
504
)
520
505
521
506
n = self .rank + 1
@@ -1091,19 +1076,13 @@ def test_call_method_on_rref(self):
1091
1076
@dist_init (setup_rpc = False )
1092
1077
def test_get_rpc_timeout (self ):
1093
1078
timeout = timedelta (seconds = 1 )
1094
-
1095
- # A new `RpcAgentOptions` is constructed
1096
- # when accessing `self.rpc_agent_options`.
1097
- rpc_agent_options = self .rpc_agent_options
1098
- rpc_agent_options .rpc_timeout = timeout
1099
-
1100
1079
rpc .init_rpc (
1101
- name = "worker{}" .format (self .rank ),
1080
+ self_name = "worker{}" .format (self .rank ),
1102
1081
backend = self .rpc_backend ,
1103
1082
init_method = self .init_method ,
1104
- rank = self .rank ,
1105
- world_size = self .world_size ,
1106
- rpc_agent_options = rpc_agent_options ,
1083
+ self_rank = self .rank ,
1084
+ worker_name_to_id = self .worker_name_to_id ,
1085
+ rpc_timeout = timeout
1107
1086
)
1108
1087
set_timeout = rpc .get_rpc_timeout ()
1109
1088
self .assertEqual (timeout , set_timeout )
0 commit comments