
Commit 8351350

rohan-varma authored and facebook-github-bot committed
poll for timed out futures in process group agent (pytorch#29601)
Summary:
Pull Request resolved: pytorch#29601

Follow-up to pytorch#28392. Adds a background thread to `ProcessGroupAgent` that polls for timed out RPCs at a pre-set interval and marks them as completed with a timeout exception once they have timed out. The timed out futures are also deleted from the corresponding maps `futures_` and `futureTimeouts_`. Unit tests are added to ensure that timed out RPCs are appropriately cleaned up. A `shutdown_` variable is also added to the process group agent to control the shutdown of this background thread, and it can eventually be extended to control a clean shutdown of the whole agent.

ghstack-source-id: 9417513

Test Plan: Added unit tests

Differential Revision: D18434215

fbshipit-source-id: c48abdb8759fe1447200ec66bb9d4b1c50ec4535
1 parent 21dc1d4 commit 8351350
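At a high level, the watchdog sleeps until the earliest pending deadline (or indefinitely when there is nothing to watch), is notified whenever a new earliest deadline is registered or on shutdown, and fails every future whose deadline has passed. The following is a minimal standalone sketch of that pattern in plain Python threading; it is not PyTorch code, and names such as ToyAgent and ToyFuture are hypothetical, illustrating only the deadline-map plus condition-variable idea (the normal completion path is omitted).

# Hypothetical sketch of the watchdog pattern; not part of the PyTorch API.
import threading
import time


class ToyFuture:
    def __init__(self):
        self._done = threading.Event()
        self.error = None

    def mark_completed(self, error=None):
        self.error = error
        self._done.set()

    def wait(self):
        self._done.wait()
        if self.error is not None:
            raise RuntimeError(self.error)


class ToyAgent:
    def __init__(self, timeout_s=0.5):
        self._timeout_s = timeout_s
        self._cv = threading.Condition()
        self._deadlines = {}   # deadline -> list of request ids
        self._futures = {}     # request id -> ToyFuture
        self._shutdown = False
        self._watchdog = threading.Thread(target=self._poll_timed_out)
        self._watchdog.start()

    def send(self, request_id):
        fut = ToyFuture()
        deadline = time.monotonic() + self._timeout_s
        with self._cv:
            self._futures[request_id] = fut
            self._deadlines.setdefault(deadline, []).append(request_id)
            # Wake the watchdog if this future introduced the earliest deadline.
            if len(self._futures) == 1 or deadline == min(self._deadlines):
                self._cv.notify()
        return fut

    def _poll_timed_out(self):
        while True:
            expired = []
            with self._cv:
                if self._shutdown:
                    return
                if not self._deadlines:
                    self._cv.wait()  # nothing to watch; sleep until notified
                else:
                    remaining = min(self._deadlines) - time.monotonic()
                    if remaining > 0:
                        self._cv.wait(timeout=remaining)
                if self._shutdown:
                    return
                # Collect every future whose deadline has passed.
                now = time.monotonic()
                for deadline in sorted(self._deadlines):
                    if deadline > now:
                        break
                    for request_id in self._deadlines.pop(deadline):
                        expired.append(self._futures.pop(request_id))
            # Mark completions outside the lock, mirroring the C++ code, since
            # callbacks may re-enter the agent.
            for fut in expired:
                fut.mark_completed(
                    "RPC ran for more than %d milliseconds and timed out."
                    % int(self._timeout_s * 1000))

    def shutdown(self):
        with self._cv:
            self._shutdown = True
            self._cv.notify()
        self._watchdog.join()

A caller of this toy agent would run fut = agent.send(1) followed by fut.wait() and observe the timeout RuntimeError, analogous to the assertRaisesRegex checks in the unit test below.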

File tree

5 files changed: +240 -34 lines


test/rpc_test.py (+34 -6)

@@ -125,6 +125,10 @@ def my_function(a, b, c):
 def my_tensor_function(a, b):
     return a + b
 
+def my_sleep_func(seconds=1):
+    import time
+    time.sleep(seconds)
+
 
 def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
     res = list_input[0]

@@ -1084,13 +1088,8 @@ def test_call_method_on_rref(self):
 
         self.assertEqual(result, sum(vals))
 
-    @dist_init
-    def test_get_default_rpc_timeout(self):
-        timeout = rpc.get_rpc_timeout()
-        self.assertEqual(timeout, rpc.constants.DEFAULT_RPC_TIMEOUT)
-
     @dist_init(setup_rpc=False)
-    def test_set_rpc_timeout(self):
+    def test_get_rpc_timeout(self):
         timeout = timedelta(seconds=1)
 
         # A new `RpcAgentOptions` is constructed

@@ -1110,6 +1109,35 @@ def test_set_rpc_timeout(self):
         self.assertEqual(timeout, set_timeout)
         rpc.join_rpc()
 
+    @dist_init
+    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+    def test_rpc_timeouts(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc._set_rpc_timeout(timedelta(milliseconds=1))
+        # futures should time out and be marked with an exception indicating it as such.
+        futs = [rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=()) for _ in range(10)]
+        for fut in futs:
+            with self.assertRaisesRegex(RuntimeError, "RPC ran for more than"):
+                fut.wait()
+
+        # ensure that if a new timeout is set old futures don't time out but new ones do.
+        rpc._set_rpc_timeout(timedelta(seconds=200))
+        # create a longstanding RPC.
+        fut1 = rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=(1,))
+        # now, set a short timeout.
+        rpc._set_rpc_timeout(timedelta(milliseconds=1))
+        # fut2 should time out, fut1 should not.
+        fut2 = rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=(1,))
+        with self.assertRaises(RuntimeError):
+            fut2.wait()
+        fut1.wait()
+
+        # future should run to completion if the timeout is zero.
+        rpc._set_rpc_timeout(timedelta(seconds=0))
+        rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=()).wait()
+
+        # reset to default timeout so shutdown messages can process cleanly.
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT)
 
     def test_requires_process_group_agent_decorator(self):
         @requires_process_group_agent("test_func did not run")

torch/csrc/distributed/rpc/init.cpp (+10 -0)

@@ -200,6 +200,16 @@ PyObject* rpc_init(PyObject* /* unused */) {
           `datetime.timedelta` instance indicating the RPC timeout.
       )");
 
+  module.def(
+      "_set_rpc_timeout",
+      [](const std::chrono::milliseconds& rpcTimeout) {
+        RpcAgent::getDefaultRpcAgent()->setRpcTimeout(rpcTimeout);
+      },
+      R"(
+          Set the timeout for all RPCs. If an RPC is not completed within this
+          time, an exception indicating it has timed out will be raised.
+      )");
+
   Py_RETURN_TRUE;
 }
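From the Python side, the new binding is meant to be used as the unit test above does. The sketch below assumes RPC has already been initialized elsewhere and that a peer named "worker1" exists (both assumptions for illustration); the calls themselves (rpc._set_rpc_timeout, rpc.rpc_async, and rpc.constants.DEFAULT_RPC_TIMEOUT) all appear in this commit.

# Usage sketch; assumes rpc has been initialized elsewhere and "worker1" exists.
from datetime import timedelta

import torch.distributed.rpc as rpc


def my_sleep_func(seconds=1):
    # Same helper that this commit adds to test/rpc_test.py.
    import time
    time.sleep(seconds)


# Any RPC that takes longer than 5 ms should now fail with a timeout error.
rpc._set_rpc_timeout(timedelta(milliseconds=5))
fut = rpc.rpc_async("worker1", my_sleep_func, args=(1,))
try:
    fut.wait()
except RuntimeError as err:
    print("timed out:", err)

# Restore the default so later messages (e.g. shutdown) are processed cleanly.
rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT)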

torch/csrc/distributed/rpc/process_group_agent.cpp (+140 -15)

@@ -84,6 +84,9 @@ std::vector<int64_t> ProcessGroupAgent::MessageCounter::snapshot() {
 
 //////////////////////// ProcessGroupAgent /////////////////////////////////
 
+const std::chrono::milliseconds INFINITE_TIMEOUT =
+    std::chrono::milliseconds::max();
+
 void ProcessGroupAgent::collectNames() {
   const std::string& workerName = workerInfo_.name_;
   const auto worldSize = pg_->getSize();

@@ -124,6 +127,7 @@ ProcessGroupAgent::ProcessGroupAgent(
           WorkerInfo(std::move(workerName), pg->getRank()),
           c10::guts::make_unique<RequestCallbackImpl>(),
           rpcTimeout),
+      shutdown_{false},
       pg_(std::move(pg)),
       sendCounts_(pg_->getSize()),
       recvCounts_(pg_->getSize()),

@@ -182,7 +186,12 @@ void ProcessGroupAgent::join() {
   //    feed it a message or kill the thread.
   // 2. A GLOO process cannot send message to itself. (there is an ongoing
   //    effort to fix this problem).
+  shutdown_.store(true);
   sync();
+  // This is needed in case no futures were created, otherwise the future
+  // timeout watchdog would sleep forever.
+
+  futureTimeoutCV_.notify_one();
   std::unique_lock<std::mutex> lock(futureMutex_);
   futureCV_.wait(
       lock, [this] { return futures_.empty() && futureTimeouts_.empty(); });

@@ -193,6 +202,7 @@ void ProcessGroupAgent::join() {
       SendWork(allWorkerInfo_[dst], Message({}, {}, MessageType::SHUTDOWN)));
   threadPool_.waitWorkComplete();
   listenerThread_.join();
+  futureTimeoutThread_.join();
   PythonRpcHandler::getInstance().cleanup();
 }
 

@@ -260,6 +270,8 @@ void ProcessGroupAgent::sync() {
 
 void ProcessGroupAgent::start() {
   listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
+  futureTimeoutThread_ =
+      std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
 }
 
 std::shared_ptr<FutureMessage> ProcessGroupAgent::send(

@@ -281,9 +293,22 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
         std::chrono::steady_clock::now().time_since_epoch());
     {
       std::lock_guard<std::mutex> lock{futureMutex_};
-      futures_[requestId] = std::make_pair(future, futureStartTime);
+      // Set infinite timeout if specified.
+      auto timeout = rpcTimeout_.load();
+      if (timeout.count() == 0) {
+        timeout = INFINITE_TIMEOUT;
+      }
+      auto futureInfo = FutureInfo(future, futureStartTime, to.id_, timeout);
+      futures_[requestId] = futureInfo;
+      auto rpcEndTime = getRPCEndTime(futureInfo);
       // insert future into timeouts map to keep track of its timeout
-      futureTimeouts_[futureStartTime].push_back(requestId);
+      futureTimeouts_[rpcEndTime].push_back(requestId);
+      // Signal the watchdog to monitor future timeouts if this is the first
+      // future created or if an RPC with a shorter TTL has been created.
+      if (futures_.size() == 1 ||
+          futureTimeouts_.begin()->first == rpcEndTime) {
+        futureTimeoutCV_.notify_one();
+      }
     }
     message.setId(requestId);
   } else {

@@ -389,27 +414,35 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
   } else if (message.isResponse()) {
     auto id = message.id();
     std::shared_ptr<FutureMessage> fm = nullptr;
-    std::chrono::milliseconds futureStartTime;
-    {
-      std::lock_guard<std::mutex> lock{futureMutex_};
-      std::tie(fm, futureStartTime) = futures_[id];
-    }
-    // Not holding lock on markCompleted as this could run callbacks that
-    // call agent_->send
-    fm->markCompleted(std::move(message));
     {
       std::lock_guard<std::mutex> lock{futureMutex_};
+      const auto& futureInfo = futures_.find(id);
+      if (futureInfo == futures_.end()) {
+        // Received a completion for a timed out future, drop the recv.
+        // RecvCounts will not be incremented here, it will be incremented
+        // by the sender who has determined the future has timed out.
+        return;
+      }
+
+      fm = futureInfo->second.future_;
+      auto rpcEndTime = getRPCEndTime(futureInfo->second);
      futures_.erase(id);
      // look up the corresponding future by its time out and request ID,
      // and remove it from the timeouts map
-      auto& futuresAtTime = futureTimeouts_[futureStartTime];
-      futuresAtTime.erase(
-          std::find(futuresAtTime.begin(), futuresAtTime.end(), id));
-      if (futuresAtTime.size() == 0) {
+      auto& futuresAtTime = futureTimeouts_[rpcEndTime];
+      auto it = std::find(futuresAtTime.begin(), futuresAtTime.end(), id);
+      TORCH_INTERNAL_ASSERT(
+          it != futuresAtTime.end(),
+          "Error: could not find future in futureTimeouts map, race condition.");
+      futuresAtTime.erase(it);
+      if (futuresAtTime.empty()) {
        // remove the key from futureTimeouts_
-        futureTimeouts_.erase(futureStartTime);
+        futureTimeouts_.erase(rpcEndTime);
      }
    }
+    // Not holding lock on markCompleted as this could run callbacks that
+    // call agent_->send
+    fm->markCompleted(std::move(message));
    futureCV_.notify_all();
  } else {
    // TODO: pass the error back to the caller instead of crashing here.

@@ -449,6 +482,98 @@ void ProcessGroupAgent::listenLoop() {
   }
 }
 
+void ProcessGroupAgent::pollTimedOutRPCs() {
+  while (!shutdown_.load()) {
+    std::chrono::milliseconds sleepTime;
+    std::unique_lock<std::mutex> lock{futureMutex_};
+    // Estimate amount of time the first future will time out in, and sleep
+    // for that long.
+    // if there are no futures or the first future's RPC timeout is set to 0
+    // (meaning no timeout), then sleep for a set "infinity" time.
+    if (futureTimeouts_.empty() ||
+        futureTimeouts_.begin()->first == INFINITE_TIMEOUT) {
+      sleepTime = INFINITE_TIMEOUT;
+    } else {
+      const auto minFutureExpirationTime = futureTimeouts_.begin()->first;
+      const auto remainingTime = getRPCRemainingTime(minFutureExpirationTime);
+      sleepTime = std::max(remainingTime, std::chrono::milliseconds(0));
+    }
+
+    if (sleepTime == INFINITE_TIMEOUT) {
+      futureTimeoutCV_.wait(lock);
+    } else {
+      futureTimeoutCV_.wait_for(lock, sleepTime);
+    }
+
+    if (shutdown_.load()) {
+      return;
+    }
+
+    const auto timedOutFutures = processTimedOutFutures();
+
+    // Do not hold the lock while marking futures completed, as markCompleted()
+    // could invoke callbacks.
+    lock.unlock();
+    for (const auto& timedOutFuture : timedOutFutures) {
+      std::ostringstream ss;
+      ss << "RPC ran for more than " << timedOutFuture.timeout_.count()
+         << " milliseconds and timed out.";
+      const auto exceptionMsg = createExceptionResponse(
+          Message({}, {}, MessageType::EXCEPTION), ss.str());
+      timedOutFuture.future_->markCompleted(exceptionMsg);
+
+      const int dst = timedOutFuture.dstRank_;
+      recvCounts_.increment(dst);
+      futureCV_.notify_all();
+    }
+  }
+}
+
+const std::vector<ProcessGroupAgent::FutureInfo> ProcessGroupAgent::
+    processTimedOutFutures() {
+  std::vector<FutureInfo> timedOutFutures;
+  for (auto it = futureTimeouts_.begin(); it != futureTimeouts_.end();
+       /* intentional no increment */) {
+    const auto& endTime = it->first;
+    const auto remainingTime = getRPCRemainingTime(endTime);
+
+    if (remainingTime.count() > 0) {
+      // Since the futureTimeouts_ map is ordered by timeout, we don't need
+      // to check the remaining futures.
+      break;
+    } else {
+      const std::vector<int64_t>& futureIDs = it->second;
+      for (const auto& futureID : futureIDs) {
+        auto futureIt = futures_.find(futureID);
+        TORCH_INTERNAL_ASSERT(
+            futureIt != futures_.end(),
+            "Race Condition - Expected future does not exist in map");
+        const auto futInfo = futureIt->second;
+        timedOutFutures.push_back(futInfo);
+        futures_.erase(futureID);
+      }
+      it = futureTimeouts_.erase(it);
+    }
+  }
+  return timedOutFutures;
+}
+
+const std::chrono::milliseconds ProcessGroupAgent::getRPCRemainingTime(
+    const std::chrono::milliseconds& rpcEndTime) const {
+  const auto remainingTime =
+      rpcEndTime -
+      std::chrono::duration_cast<std::chrono::milliseconds>(
+          std::chrono::steady_clock::now().time_since_epoch());
+  return remainingTime;
+}
+
+const std::chrono::milliseconds ProcessGroupAgent::getRPCEndTime(
+    const FutureInfo& futureInfo) const {
+  return futureInfo.timeout_ == INFINITE_TIMEOUT
+      ? INFINITE_TIMEOUT
+      : futureInfo.startTime_ + futureInfo.timeout_;
+}
+
 } // namespace rpc
 } // namespace distributed
 } // namespace torch
