12
12
import torch .distributed .rpc as rpc
13
13
from torch .distributed .rpc import RRef
14
14
from common_utils import load_tests
15
- import dist_utils
16
- from dist_utils import dist_init
15
+ from dist_utils import INIT_METHOD_TEMPLATE , TEST_CONFIG , dist_init
17
16
from torch .distributed .rpc .internal import PythonUDF , _internal_rpc_pickler
18
- from rpc_agent_test_fixture import RpcAgentTestFixture
19
17
20
18
21
19
def requires_process_group_agent (message = "" ):
22
20
def decorator (old_func ):
23
21
return unittest .skipUnless (
24
- dist_utils . TEST_CONFIG .rpc_backend_name == "PROCESS_GROUP" , message
22
+ TEST_CONFIG .rpc_backend_name == "PROCESS_GROUP" , message
25
23
)(old_func )
26
24
27
25
return decorator
@@ -212,7 +210,15 @@ def raise_func():
212
210
sys .version_info < (3 , 0 ),
213
211
"Pytorch distributed rpc package " "does not support python2" ,
214
212
)
215
- class RpcTest (RpcAgentTestFixture ):
213
+ class RpcTest (object ):
214
+ @property
215
+ def world_size (self ):
216
+ return 4
217
+
218
+ @property
219
+ def init_method (self ):
220
+ return INIT_METHOD_TEMPLATE .format (file_name = self .file_name )
221
+
216
222
@dist_init
217
223
def test_worker_id (self ):
218
224
n = self .rank + 1
@@ -308,7 +314,7 @@ def test_duplicate_name(self):
308
314
self .init_method , rank = self .rank , world_size = self .world_size
309
315
))
310
316
rpc ._init_rpc_backend (
311
- backend = self . rpc_backend ,
317
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
312
318
store = store ,
313
319
self_name = "duplicate_name" ,
314
320
self_rank = self .rank ,
@@ -320,7 +326,7 @@ def test_duplicate_name(self):
320
326
def test_reinit (self ):
321
327
rpc .init_rpc (
322
328
self_name = "worker{}" .format (self .rank ),
323
- backend = self . rpc_backend ,
329
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
324
330
init_method = self .init_method ,
325
331
self_rank = self .rank ,
326
332
worker_name_to_id = self .worker_name_to_id ,
@@ -342,7 +348,7 @@ def test_reinit(self):
342
348
with self .assertRaisesRegex (RuntimeError , "is already initialized" ):
343
349
rpc .init_rpc (
344
350
self_name = "worker{}" .format (self .rank ),
345
- backend = self . rpc_backend ,
351
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
346
352
init_method = self .init_method ,
347
353
self_rank = self .rank ,
348
354
worker_name_to_id = self .worker_name_to_id ,
@@ -356,7 +362,7 @@ def test_invalid_names(self):
356
362
self .init_method , rank = self .rank , world_size = self .world_size
357
363
))
358
364
rpc ._init_rpc_backend (
359
- backend = self . rpc_backend ,
365
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
360
366
store = store ,
361
367
self_name = "abc*" ,
362
368
self_rank = self .rank ,
@@ -373,7 +379,7 @@ def test_invalid_names(self):
373
379
self .init_method , rank = self .rank , world_size = self .world_size
374
380
))
375
381
rpc ._init_rpc_backend (
376
- backend = self . rpc_backend ,
382
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
377
383
store = store ,
378
384
self_name = " " ,
379
385
self_rank = self .rank ,
@@ -388,7 +394,7 @@ def test_invalid_names(self):
388
394
self .init_method , rank = self .rank , world_size = self .world_size
389
395
))
390
396
rpc ._init_rpc_backend (
391
- backend = self . rpc_backend ,
397
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
392
398
store = store ,
393
399
self_name = "" ,
394
400
self_rank = self .rank ,
@@ -405,7 +411,7 @@ def test_invalid_names(self):
405
411
self .init_method , rank = self .rank , world_size = self .world_size
406
412
))
407
413
rpc ._init_rpc_backend (
408
- backend = self . rpc_backend ,
414
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
409
415
store = store ,
410
416
self_name = "" .join (["a" for i in range (500 )]),
411
417
self_rank = self .rank ,
@@ -497,7 +503,7 @@ def test_join_rpc(self):
497
503
# Initialize RPC.
498
504
rpc .init_rpc (
499
505
self_name = "worker%d" % self .rank ,
500
- backend = self . rpc_backend ,
506
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
501
507
init_method = self .init_method ,
502
508
self_rank = self .rank ,
503
509
worker_name_to_id = self .worker_name_to_id ,
@@ -1078,7 +1084,7 @@ def test_get_rpc_timeout(self):
1078
1084
timeout = timedelta (seconds = 1 )
1079
1085
rpc .init_rpc (
1080
1086
self_name = "worker{}" .format (self .rank ),
1081
- backend = self . rpc_backend ,
1087
+ backend = rpc . backend_registry . BackendType [ TEST_CONFIG . rpc_backend_name ] ,
1082
1088
init_method = self .init_method ,
1083
1089
self_rank = self .rank ,
1084
1090
worker_name_to_id = self .worker_name_to_id ,
@@ -1123,7 +1129,7 @@ def test_requires_process_group_agent_decorator(self):
1123
1129
def test_func ():
1124
1130
return "expected result"
1125
1131
1126
- if dist_utils . TEST_CONFIG .rpc_backend_name == "PROCESS_GROUP" :
1132
+ if TEST_CONFIG .rpc_backend_name == "PROCESS_GROUP" :
1127
1133
self .assertEqual (test_func (), "expected result" )
1128
1134
1129
1135
def test_dist_init_decorator (self ):
0 commit comments