Skip to content

Commit 7d28768

Browse files
ezyangfacebook-github-bot
authored andcommitted
Revert D5689636: Add RpcAgentTestFixture to extract duplicate code
Test Plan: revert-hammer Differential Revision: D5689636 Original commit changeset: f35eea1359ad fbshipit-source-id: 31928fce5e96b3beceefbc9a03f54769f10b7e1a
1 parent 1dda818 commit 7d28768

6 files changed

+50
-50
lines changed

test/dist_autograd_test.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
import torch.distributed as dist
88
import torch.distributed.autograd as dist_autograd
99
import torch.distributed.rpc as rpc
10-
import dist_utils
11-
from dist_utils import dist_init
12-
from rpc_agent_test_fixture import RpcAgentTestFixture
10+
from dist_utils import INIT_METHOD_TEMPLATE, dist_init, TEST_CONFIG
1311

1412
import threading
1513

@@ -161,7 +159,7 @@ class ExecMode(Enum):
161159
@unittest.skipIf(
162160
not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2"
163161
)
164-
class DistAutogradTest(RpcAgentTestFixture):
162+
class DistAutogradTest(object):
165163

166164
def _initialize_pg(self):
167165
# This is for tests using `dist.barrier`.
@@ -202,6 +200,14 @@ def _next_rank(self):
202200
def _check_rpc_done(self, rank_distance):
203201
_check_rpc_done(rank_distance)
204202

203+
@property
204+
def world_size(self):
205+
return 4
206+
207+
@property
208+
def init_method(self):
209+
return INIT_METHOD_TEMPLATE.format(file_name=self.file_name)
210+
205211
@dist_init
206212
def test_autograd_context(self):
207213
# Verify max possible id.
@@ -1049,7 +1055,7 @@ def test_backward_autograd_engine_error(self):
10491055
# Run backwards, and validate we receive an error.
10501056
dist_autograd.backward([val.sum()])
10511057

1052-
@unittest.skipIf(dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP",
1058+
@unittest.skipIf(TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP",
10531059
"Skipping this test temporarily since ProcessGroupAgent does not report errors on node failures")
10541060
@dist_init(clean_shutdown=False)
10551061
def test_backward_node_failure(self):
@@ -1220,7 +1226,7 @@ def _wait_backward_done():
12201226
while not DistAutogradTest._backward_done:
12211227
time.sleep(0.1)
12221228

1223-
@unittest.skipIf(dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP",
1229+
@unittest.skipIf(TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP",
12241230
"Skipping this test temporarily since ProcessGroupAgent " +
12251231
"does not report errors on node failures")
12261232
@dist_init(clean_shutdown=False)

test/dist_optimizer_test.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
import unittest
44

5-
from dist_utils import dist_init
5+
from dist_utils import INIT_METHOD_TEMPLATE, dist_init
66
from torch import optim
77
from torch.distributed.optim import DistributedOptimizer
88
import torch
99
import torch.distributed.autograd as dist_autograd
1010
import torch.distributed.rpc as rpc
1111
import threading
12-
from rpc_agent_test_fixture import RpcAgentTestFixture
1312

1413

1514
class MyModule:
@@ -84,7 +83,17 @@ def rpc_async_method(method, obj_rref, *args, **kwargs):
8483
@unittest.skipIf(
8584
not torch._six.PY3, "Pytorch distributed optim does not support python2"
8685
)
87-
class DistOptimizerTest(RpcAgentTestFixture):
86+
class DistOptimizerTest(object):
87+
88+
@property
89+
def world_size(self):
90+
return 4
91+
92+
@property
93+
def init_method(self):
94+
return INIT_METHOD_TEMPLATE.format(
95+
file_name=self.file_name, rank=self.rank, world_size=self.world_size
96+
)
8897

8998
@dist_init()
9099
def test_dist_optim_exception(self):

test/dist_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def new_test_method(self, *arg, **kwargs):
8585
# Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
8686
rpc.init_rpc(
8787
self_name="worker%d" % self.rank,
88-
backend=self.rpc_backend,
88+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
8989
init_method=self.init_method,
9090
self_rank=self.rank,
9191
worker_name_to_id=self.worker_name_to_id,

test/rpc_agent_test_fixture.py

-22
This file was deleted.

test/rpc_test.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@
1212
import torch.distributed.rpc as rpc
1313
from torch.distributed.rpc import RRef
1414
from common_utils import load_tests
15-
import dist_utils
16-
from dist_utils import dist_init
15+
from dist_utils import INIT_METHOD_TEMPLATE, TEST_CONFIG, dist_init
1716
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
18-
from rpc_agent_test_fixture import RpcAgentTestFixture
1917

2018

2119
def requires_process_group_agent(message=""):
2220
def decorator(old_func):
2321
return unittest.skipUnless(
24-
dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP", message
22+
TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP", message
2523
)(old_func)
2624

2725
return decorator
@@ -212,7 +210,15 @@ def raise_func():
212210
sys.version_info < (3, 0),
213211
"Pytorch distributed rpc package " "does not support python2",
214212
)
215-
class RpcTest(RpcAgentTestFixture):
213+
class RpcTest(object):
214+
@property
215+
def world_size(self):
216+
return 4
217+
218+
@property
219+
def init_method(self):
220+
return INIT_METHOD_TEMPLATE.format(file_name=self.file_name)
221+
216222
@dist_init
217223
def test_worker_id(self):
218224
n = self.rank + 1
@@ -308,7 +314,7 @@ def test_duplicate_name(self):
308314
self.init_method, rank=self.rank, world_size=self.world_size
309315
))
310316
rpc._init_rpc_backend(
311-
backend=self.rpc_backend,
317+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
312318
store=store,
313319
self_name="duplicate_name",
314320
self_rank=self.rank,
@@ -320,7 +326,7 @@ def test_duplicate_name(self):
320326
def test_reinit(self):
321327
rpc.init_rpc(
322328
self_name="worker{}".format(self.rank),
323-
backend=self.rpc_backend,
329+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
324330
init_method=self.init_method,
325331
self_rank=self.rank,
326332
worker_name_to_id=self.worker_name_to_id,
@@ -342,7 +348,7 @@ def test_reinit(self):
342348
with self.assertRaisesRegex(RuntimeError, "is already initialized"):
343349
rpc.init_rpc(
344350
self_name="worker{}".format(self.rank),
345-
backend=self.rpc_backend,
351+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
346352
init_method=self.init_method,
347353
self_rank=self.rank,
348354
worker_name_to_id=self.worker_name_to_id,
@@ -356,7 +362,7 @@ def test_invalid_names(self):
356362
self.init_method, rank=self.rank, world_size=self.world_size
357363
))
358364
rpc._init_rpc_backend(
359-
backend=self.rpc_backend,
365+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
360366
store=store,
361367
self_name="abc*",
362368
self_rank=self.rank,
@@ -373,7 +379,7 @@ def test_invalid_names(self):
373379
self.init_method, rank=self.rank, world_size=self.world_size
374380
))
375381
rpc._init_rpc_backend(
376-
backend=self.rpc_backend,
382+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
377383
store=store,
378384
self_name=" ",
379385
self_rank=self.rank,
@@ -388,7 +394,7 @@ def test_invalid_names(self):
388394
self.init_method, rank=self.rank, world_size=self.world_size
389395
))
390396
rpc._init_rpc_backend(
391-
backend=self.rpc_backend,
397+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
392398
store=store,
393399
self_name="",
394400
self_rank=self.rank,
@@ -405,7 +411,7 @@ def test_invalid_names(self):
405411
self.init_method, rank=self.rank, world_size=self.world_size
406412
))
407413
rpc._init_rpc_backend(
408-
backend=self.rpc_backend,
414+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
409415
store=store,
410416
self_name="".join(["a" for i in range(500)]),
411417
self_rank=self.rank,
@@ -497,7 +503,7 @@ def test_join_rpc(self):
497503
# Initialize RPC.
498504
rpc.init_rpc(
499505
self_name="worker%d" % self.rank,
500-
backend=self.rpc_backend,
506+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
501507
init_method=self.init_method,
502508
self_rank=self.rank,
503509
worker_name_to_id=self.worker_name_to_id,
@@ -1078,7 +1084,7 @@ def test_get_rpc_timeout(self):
10781084
timeout = timedelta(seconds=1)
10791085
rpc.init_rpc(
10801086
self_name="worker{}".format(self.rank),
1081-
backend=self.rpc_backend,
1087+
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
10821088
init_method=self.init_method,
10831089
self_rank=self.rank,
10841090
worker_name_to_id=self.worker_name_to_id,
@@ -1123,7 +1129,7 @@ def test_requires_process_group_agent_decorator(self):
11231129
def test_func():
11241130
return "expected result"
11251131

1126-
if dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP":
1132+
if TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP":
11271133
self.assertEqual(test_func(), "expected result")
11281134

11291135
def test_dist_init_decorator(self):

test/run_test.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
TESTS.extend([
7171
'rpc_spawn',
7272
'dist_autograd_spawn',
73-
'dist_optimizer_spawn',
7473
])
7574

7675
# skip < 3.6 b/c fstrings added in 3.6
@@ -81,18 +80,20 @@
8180

8281
WINDOWS_BLACKLIST = [
8382
'distributed',
83+
'rpc_fork',
8484
'rpc_spawn',
85+
'dist_autograd_fork',
8586
'dist_autograd_spawn',
86-
'dist_optimizer_spawn',
8787
]
8888

8989
ROCM_BLACKLIST = [
9090
'cpp_extensions',
9191
'distributed',
9292
'multiprocessing',
93+
'rpc_fork',
9394
'rpc_spawn',
95+
'dist_autograd_fork',
9496
'dist_autograd_spawn',
95-
'dist_optimizer_spawn',
9697
]
9798

9899
DISTRIBUTED_TESTS_CONFIG = {}

0 commit comments

Comments
 (0)