Skip to content

Commit 1dda818

Browse files
ezyangfacebook-github-bot
authored andcommitted
Revert D18549919: Add RpcAgentOptions struct type, which bundles different required arguments for different RpcAgents
Test Plan: revert-hammer Differential Revision: D18549919 Original commit changeset: b9f3f1a41d1f fbshipit-source-id: 2d5e578d18c0725b59eb99a0e942fbf7fe3341ee
1 parent 861ef05 commit 1dda818

File tree

10 files changed

+120
-221
lines changed

10 files changed

+120
-221
lines changed

test/dist_utils.py

+12-18
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import threading
44
from functools import partial, wraps
5+
from os import getenv
56

67
import torch.distributed as dist
78
import torch.distributed.rpc as rpc
@@ -13,15 +14,15 @@
1314

1415

1516
class TestConfig:
16-
__slots__ = ["rpc_backend_name", "build_rpc_agent_options"]
17+
__slots__ = ["rpc_backend_name"]
1718

1819
def __init__(self, *args, **kwargs):
1920
assert len(args) == 0, "TestConfig only takes kwargs."
2021
for k, v in kwargs.items():
2122
setattr(self, k, v)
2223

2324

24-
TEST_CONFIG = TestConfig()
25+
TEST_CONFIG = TestConfig(rpc_backend_name=getenv("RPC_BACKEND_NAME", "PROCESS_GROUP"))
2526
INIT_METHOD_TEMPLATE = "file://{file_name}"
2627

2728

@@ -73,20 +74,22 @@ def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
7374
@wraps(old_test_method)
7475
def new_test_method(self, *arg, **kwargs):
7576
self.worker_id = self.rank
77+
self.worker_name_to_id = {
78+
"worker{}".format(rank): rank for rank in range(self.world_size)
79+
}
7680

7781
if setup_rpc:
7882
global _ALL_NODE_NAMES
79-
_ALL_NODE_NAMES = {
80-
"worker{}".format(rank) for rank in range(self.world_size)
81-
}
83+
_ALL_NODE_NAMES = self.worker_name_to_id.keys()
8284

85+
# Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
8386
rpc.init_rpc(
84-
name="worker%d" % self.rank,
87+
self_name="worker%d" % self.rank,
8588
backend=self.rpc_backend,
8689
init_method=self.init_method,
87-
rank=self.rank,
88-
world_size=self.world_size,
89-
rpc_agent_options=self.rpc_agent_options,
90+
self_rank=self.rank,
91+
worker_name_to_id=self.worker_name_to_id,
92+
num_send_recv_threads=16,
9093
)
9194

9295
return_value = old_test_method(self, *arg, **kwargs)
@@ -128,12 +131,3 @@ def new_test_method(self, *arg, **kwargs):
128131
return return_value
129132

130133
return new_test_method
131-
132-
133-
# Set PROCESS_GROUP as the default RPC backend.
134-
TEST_CONFIG.rpc_backend_name = "PROCESS_GROUP"
135-
TEST_CONFIG.build_rpc_agent_options = lambda test_object: rpc.backend_registry.construct_rpc_agent_options(
136-
test_object.rpc_backend,
137-
# Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
138-
num_send_recv_threads=16,
139-
)

test/rpc_test.py

+40-61
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,8 @@ def decorator(old_func):
3030
VALUE_FUTURE = concurrent.futures.Future()
3131

3232

33-
def _stub_construct_rpc_agent_options_handler(
34-
**kwargs
35-
):
36-
return mock.Mock() # RpcAgentOptions.
37-
38-
39-
def _stub_start_rpc_backend_handler(
40-
store, name, rank, world_size, rpc_agent_options
33+
def stub_start_rpc_backend_handler(
34+
store, self_name, self_rank, worker_name_to_id, *args, **kwargs
4135
):
4236
return mock.Mock() # RpcAgent.
4337

@@ -288,27 +282,22 @@ def test_register_rpc_backend_and_start_rpc_backend(
288282
backend_name = "stub_backend"
289283

290284
backend = rpc.backend_registry.register_backend(
291-
backend_name,
292-
_stub_construct_rpc_agent_options_handler,
293-
_stub_start_rpc_backend_handler,
285+
backend_name, stub_start_rpc_backend_handler
294286
)
295287

296288
with self.assertRaisesRegex(
297289
RuntimeError, "^RPC backend .+: already registered$"
298290
):
299-
backend = rpc.backend_registry.register_backend(
300-
backend_name,
301-
_stub_construct_rpc_agent_options_handler,
302-
_stub_start_rpc_backend_handler,
291+
rpc.backend_registry.register_backend(
292+
backend_name, stub_start_rpc_backend_handler
303293
)
304294

305295
rpc.init_rpc(
306-
name="worker1",
296+
self_name="worker1",
307297
backend=backend,
308298
init_method=self.init_method,
309-
rank=self.rank,
310-
world_size=self.world_size,
311-
rpc_agent_options=self.rpc_agent_options,
299+
self_rank=self.rank,
300+
worker_name_to_id=self.worker_name_to_id,
312301
)
313302

314303
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
@@ -321,22 +310,20 @@ def test_duplicate_name(self):
321310
rpc._init_rpc_backend(
322311
backend=self.rpc_backend,
323312
store=store,
324-
name="duplicate_name",
325-
rank=self.rank,
326-
world_size=self.world_size,
327-
rpc_agent_options=self.rpc_agent_options,
313+
self_name="duplicate_name",
314+
self_rank=self.rank,
315+
worker_name_to_id=self.worker_name_to_id,
328316
)
329317
rpc.join_rpc()
330318

331319
@dist_init(setup_rpc=False)
332320
def test_reinit(self):
333321
rpc.init_rpc(
334-
name="worker{}".format(self.rank),
322+
self_name="worker{}".format(self.rank),
335323
backend=self.rpc_backend,
336324
init_method=self.init_method,
337-
rank=self.rank,
338-
world_size=self.world_size,
339-
rpc_agent_options=self.rpc_agent_options,
325+
self_rank=self.rank,
326+
worker_name_to_id=self.worker_name_to_id,
340327
)
341328

342329
# This is for the below `dist.barrier`.
@@ -354,12 +341,11 @@ def test_reinit(self):
354341

355342
with self.assertRaisesRegex(RuntimeError, "is already initialized"):
356343
rpc.init_rpc(
357-
name="worker{}".format(self.rank),
344+
self_name="worker{}".format(self.rank),
358345
backend=self.rpc_backend,
359346
init_method=self.init_method,
360-
rank=self.rank,
361-
world_size=self.world_size,
362-
rpc_agent_options=self.rpc_agent_options,
347+
self_rank=self.rank,
348+
worker_name_to_id=self.worker_name_to_id,
363349
)
364350
rpc.join_rpc()
365351

@@ -372,10 +358,10 @@ def test_invalid_names(self):
372358
rpc._init_rpc_backend(
373359
backend=self.rpc_backend,
374360
store=store,
375-
name="abc*",
376-
rank=self.rank,
377-
world_size=self.world_size,
378-
rpc_agent_options=self.rpc_agent_options,
361+
self_name="abc*",
362+
self_rank=self.rank,
363+
worker_name_to_id=self.worker_name_to_id,
364+
num_send_recv_threads=16,
379365
)
380366

381367
base_file_name = self.file_name
@@ -389,10 +375,10 @@ def test_invalid_names(self):
389375
rpc._init_rpc_backend(
390376
backend=self.rpc_backend,
391377
store=store,
392-
name=" ",
393-
rank=self.rank,
394-
world_size=self.world_size,
395-
rpc_agent_options=self.rpc_agent_options,
378+
self_name=" ",
379+
self_rank=self.rank,
380+
worker_name_to_id=self.worker_name_to_id,
381+
num_send_recv_threads=16,
396382
)
397383

398384
# Use a different file path for FileStore to avoid rendezvous mismatch.
@@ -404,10 +390,10 @@ def test_invalid_names(self):
404390
rpc._init_rpc_backend(
405391
backend=self.rpc_backend,
406392
store=store,
407-
name="",
408-
rank=self.rank,
409-
world_size=self.world_size,
410-
rpc_agent_options=self.rpc_agent_options,
393+
self_name="",
394+
self_rank=self.rank,
395+
worker_name_to_id=self.worker_name_to_id,
396+
num_send_recv_threads=16,
411397
)
412398

413399
# Use a different file path for FileStore to avoid rendezvous mismatch.
@@ -421,10 +407,10 @@ def test_invalid_names(self):
421407
rpc._init_rpc_backend(
422408
backend=self.rpc_backend,
423409
store=store,
424-
name="".join(["a" for i in range(500)]),
425-
rank=self.rank,
426-
world_size=self.world_size,
427-
rpc_agent_options=self.rpc_agent_options,
410+
self_name="".join(["a" for i in range(500)]),
411+
self_rank=self.rank,
412+
worker_name_to_id=self.worker_name_to_id,
413+
num_send_recv_threads=16,
428414
)
429415

430416
from torch.distributed.rpc.api import _agent
@@ -510,12 +496,11 @@ def test_multi_rpc(self):
510496
def test_join_rpc(self):
511497
# Initialize RPC.
512498
rpc.init_rpc(
513-
name="worker%d" % self.rank,
499+
self_name="worker%d" % self.rank,
514500
backend=self.rpc_backend,
515501
init_method=self.init_method,
516-
rank=self.rank,
517-
world_size=self.world_size,
518-
rpc_agent_options=self.rpc_agent_options,
502+
self_rank=self.rank,
503+
worker_name_to_id=self.worker_name_to_id,
519504
)
520505

521506
n = self.rank + 1
@@ -1091,19 +1076,13 @@ def test_call_method_on_rref(self):
10911076
@dist_init(setup_rpc=False)
10921077
def test_get_rpc_timeout(self):
10931078
timeout = timedelta(seconds=1)
1094-
1095-
# A new `RpcAgentOptions` is constructed
1096-
# when accessing `self.rpc_agent_options`.
1097-
rpc_agent_options = self.rpc_agent_options
1098-
rpc_agent_options.rpc_timeout = timeout
1099-
11001079
rpc.init_rpc(
1101-
name="worker{}".format(self.rank),
1080+
self_name="worker{}".format(self.rank),
11021081
backend=self.rpc_backend,
11031082
init_method=self.init_method,
1104-
rank=self.rank,
1105-
world_size=self.world_size,
1106-
rpc_agent_options=rpc_agent_options,
1083+
self_rank=self.rank,
1084+
worker_name_to_id=self.worker_name_to_id,
1085+
rpc_timeout=timeout
11071086
)
11081087
set_timeout = rpc.get_rpc_timeout()
11091088
self.assertEqual(timeout, set_timeout)

torch/csrc/distributed/rpc/init.cpp

-11
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@ PyObject* rpc_init(PyObject* /* unused */) {
3434

3535
auto module = py::handle(rpc_module).cast<py::module>();
3636

37-
auto rpcAgentOptions =
38-
shared_ptr_class_<RpcAgentOptions>(module, "RpcAgentOptions")
39-
.def_readwrite("rpc_timeout", &RpcAgentOptions::rpcTimeout);
40-
4137
auto workerInfo =
4238
shared_ptr_class_<WorkerInfo>(module, "WorkerInfo")
4339
.def_readonly("name", &WorkerInfo::name_)
@@ -102,13 +98,6 @@ PyObject* rpc_init(PyObject* /* unused */) {
10298
[&](FutureMessage& fut) { return toPyObj(fut.wait()); },
10399
py::call_guard<py::gil_scoped_release>());
104100

105-
shared_ptr_class_<ProcessGroupRpcAgentOptions>(
106-
module, "ProcessGroupRpcAgentOptions", rpcAgentOptions)
107-
.def(py::init<>())
108-
.def_readwrite(
109-
"num_send_recv_threads",
110-
&ProcessGroupRpcAgentOptions::numSendRecvThreads);
111-
112101
shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
113102
.def(
114103
py::init<

torch/csrc/distributed/rpc/process_group_agent.h

-5
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@ namespace torch {
1313
namespace distributed {
1414
namespace rpc {
1515

16-
struct ProcessGroupRpcAgentOptions : public RpcAgentOptions {
17-
ProcessGroupRpcAgentOptions() noexcept = default;
18-
int numSendRecvThreads;
19-
};
20-
2116
// SendWork and RecvWork will be put into a task queue, and later picked up by
2217
// worker threads from the same ThreadPool.
2318
struct SendWork {

torch/csrc/distributed/rpc/rpc_agent.h

-5
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@ namespace torch {
1212
namespace distributed {
1313
namespace rpc {
1414

15-
struct RpcAgentOptions {
16-
RpcAgentOptions() noexcept = default;
17-
std::chrono::milliseconds rpcTimeout;
18-
};
19-
2015
// A globally unique ID to identify an RpcAgent
2116
struct TORCH_API WorkerInfo {
2217
WorkerInfo(std::string name, int id)

torch/distributed/rendezvous.py

-11
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
except ImportError:
44
from urlparse import urlparse
55

6-
import torch._six as six
7-
import numbers
86
import os
97
from . import FileStore, TCPStore
108

@@ -44,15 +42,6 @@ def register_rendezvous_handler(scheme, handler):
4442

4543

4644
def rendezvous(url, rank=-1, world_size=-1, **kwargs):
47-
if not isinstance(url, six.string_classes):
48-
raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
49-
50-
if not isinstance(rank, numbers.Integral):
51-
raise RuntimeError("`rank` must be an integer. {}".format(rank))
52-
53-
if not isinstance(world_size, numbers.Integral):
54-
raise RuntimeError("`world_size` must be an integer. {}".format(world_size))
55-
5645
# Append node-specific arguments.
5746
if rank != -1 or world_size != -1:
5847
assert (

0 commit comments

Comments
 (0)