@@ -84,6 +84,9 @@ std::vector<int64_t> ProcessGroupAgent::MessageCounter::snapshot() {
84
84
85
85
// ////////////////////// ProcessGroupAgent /////////////////////////////////
86
86
87
// Sentinel duration meaning "this RPC never times out": the largest value a
// std::chrono::milliseconds can represent.
const std::chrono::milliseconds INFINITE_TIMEOUT(
    std::chrono::milliseconds::max());
89
+
87
90
void ProcessGroupAgent::collectNames () {
88
91
const std::string& workerName = workerInfo_.name_ ;
89
92
const auto worldSize = pg_->getSize ();
@@ -124,6 +127,7 @@ ProcessGroupAgent::ProcessGroupAgent(
124
127
WorkerInfo (std::move(workerName), pg->getRank()),
125
128
c10::guts::make_unique<RequestCallbackImpl>(),
126
129
rpcTimeout),
130
+ shutdown_{false },
127
131
pg_ (std::move(pg)),
128
132
sendCounts_(pg_->getSize ()),
129
133
recvCounts_(pg_->getSize ()),
@@ -182,7 +186,12 @@ void ProcessGroupAgent::join() {
182
186
// feed it a message or kill the thread.
183
187
// 2. A GLOO process cannot send message to itself. (there is an ongoing
184
188
// effort to fix this problem).
189
+ shutdown_.store (true );
185
190
sync ();
191
+ // This is needed in case no futures were created, otherwise the future
192
+ // timeout watchdog would sleep forever.
193
+
194
+ futureTimeoutCV_.notify_one ();
186
195
std::unique_lock<std::mutex> lock (futureMutex_);
187
196
futureCV_.wait (
188
197
lock, [this ] { return futures_.empty () && futureTimeouts_.empty (); });
@@ -193,6 +202,7 @@ void ProcessGroupAgent::join() {
193
202
SendWork (allWorkerInfo_[dst], Message ({}, {}, MessageType::SHUTDOWN)));
194
203
threadPool_.waitWorkComplete ();
195
204
listenerThread_.join ();
205
+ futureTimeoutThread_.join ();
196
206
PythonRpcHandler::getInstance ().cleanup ();
197
207
}
198
208
@@ -260,6 +270,8 @@ void ProcessGroupAgent::sync() {
260
270
261
271
void ProcessGroupAgent::start () {
262
272
listenerThread_ = std::thread (&ProcessGroupAgent::listenLoop, this );
273
+ futureTimeoutThread_ =
274
+ std::thread (&ProcessGroupAgent::pollTimedOutRPCs, this );
263
275
}
264
276
265
277
std::shared_ptr<FutureMessage> ProcessGroupAgent::send (
@@ -281,9 +293,22 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
281
293
std::chrono::steady_clock::now ().time_since_epoch ());
282
294
{
283
295
std::lock_guard<std::mutex> lock{futureMutex_};
284
- futures_[requestId] = std::make_pair (future, futureStartTime);
296
+ // Set infinite timeout if specified.
297
+ auto timeout = rpcTimeout_.load ();
298
+ if (timeout.count () == 0 ) {
299
+ timeout = INFINITE_TIMEOUT;
300
+ }
301
+ auto futureInfo = FutureInfo (future, futureStartTime, to.id_ , timeout);
302
+ futures_[requestId] = futureInfo;
303
+ auto rpcEndTime = getRPCEndTime (futureInfo);
285
304
// insert future into timeouts map to keep track of its timeout
286
- futureTimeouts_[futureStartTime].push_back (requestId);
305
+ futureTimeouts_[rpcEndTime].push_back (requestId);
306
+ // Signal the watchdog to monitor future timeouts if this is the first
307
+ // future created or if an RPC with a shorter TTL has been created.
308
+ if (futures_.size () == 1 ||
309
+ futureTimeouts_.begin ()->first == rpcEndTime) {
310
+ futureTimeoutCV_.notify_one ();
311
+ }
287
312
}
288
313
message.setId (requestId);
289
314
} else {
@@ -389,27 +414,35 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
389
414
} else if (message.isResponse ()) {
390
415
auto id = message.id ();
391
416
std::shared_ptr<FutureMessage> fm = nullptr ;
392
- std::chrono::milliseconds futureStartTime;
393
- {
394
- std::lock_guard<std::mutex> lock{futureMutex_};
395
- std::tie (fm, futureStartTime) = futures_[id];
396
- }
397
- // Not holding lock on markCompleted as this could run callbacks that
398
- // call agent_->send
399
- fm->markCompleted (std::move (message));
400
417
{
401
418
std::lock_guard<std::mutex> lock{futureMutex_};
419
+ const auto & futureInfo = futures_.find (id);
420
+ if (futureInfo == futures_.end ()) {
421
+ // Received a completion for a timed out future, drop the recv.
422
+ // RecvCounts will not be incremented here, it will be incremented
423
+ // by the sender who has determined the future has timed out.
424
+ return ;
425
+ }
426
+
427
+ fm = futureInfo->second .future_ ;
428
+ auto rpcEndTime = getRPCEndTime (futureInfo->second );
402
429
futures_.erase (id);
403
430
// look up the corresponding future by its time out and request ID,
404
431
// and remove it from the timeouts map
405
- auto & futuresAtTime = futureTimeouts_[futureStartTime];
406
- futuresAtTime.erase (
407
- std::find (futuresAtTime.begin (), futuresAtTime.end (), id));
408
- if (futuresAtTime.size () == 0 ) {
432
+ auto & futuresAtTime = futureTimeouts_[rpcEndTime];
433
+ auto it = std::find (futuresAtTime.begin (), futuresAtTime.end (), id);
434
+ TORCH_INTERNAL_ASSERT (
435
+ it != futuresAtTime.end (),
436
+ " Error: could not find future in futureTimeouts map, race condition." );
437
+ futuresAtTime.erase (it);
438
+ if (futuresAtTime.empty ()) {
409
439
// remove the key from futureTimeouts_
410
- futureTimeouts_.erase (futureStartTime );
440
+ futureTimeouts_.erase (rpcEndTime );
411
441
}
412
442
}
443
+ // Not holding lock on markCompleted as this could run callbacks that
444
+ // call agent_->send
445
+ fm->markCompleted (std::move (message));
413
446
futureCV_.notify_all ();
414
447
} else {
415
448
// TODO: pass the error back to the caller instead of crashing here.
@@ -449,6 +482,98 @@ void ProcessGroupAgent::listenLoop() {
449
482
}
450
483
}
451
484
485
+ void ProcessGroupAgent::pollTimedOutRPCs () {
486
+ while (!shutdown_.load ()) {
487
+ std::chrono::milliseconds sleepTime;
488
+ std::unique_lock<std::mutex> lock{futureMutex_};
489
+ // Estimate amount of time the first future will time out in, and sleep
490
+ // for that long.
491
+ // if there are no futures or the first future's RPC timeout is set to 0
492
+ // (meaning no timeout), then sleep for a set "infinity" time.
493
+ if (futureTimeouts_.empty () ||
494
+ futureTimeouts_.begin ()->first == INFINITE_TIMEOUT) {
495
+ sleepTime = INFINITE_TIMEOUT;
496
+ } else {
497
+ const auto minFutureExpirationTime = futureTimeouts_.begin ()->first ;
498
+ const auto remainingTime = getRPCRemainingTime (minFutureExpirationTime);
499
+ sleepTime = std::max (remainingTime, std::chrono::milliseconds (0 ));
500
+ }
501
+
502
+ if (sleepTime == INFINITE_TIMEOUT) {
503
+ futureTimeoutCV_.wait (lock);
504
+ } else {
505
+ futureTimeoutCV_.wait_for (lock, sleepTime);
506
+ }
507
+
508
+ if (shutdown_.load ()) {
509
+ return ;
510
+ }
511
+
512
+ const auto timedOutFutures = processTimedOutFutures ();
513
+
514
+ // Do not hold the lock while marking futures completed, as markCompleted()
515
+ // could invoke callbacks.
516
+ lock.unlock ();
517
+ for (const auto & timedOutFuture : timedOutFutures) {
518
+ std::ostringstream ss;
519
+ ss << " RPC ran for more than " << timedOutFuture.timeout_ .count ()
520
+ << " milliseconds and timed out." ;
521
+ const auto exceptionMsg = createExceptionResponse (
522
+ Message ({}, {}, MessageType::EXCEPTION), ss.str ());
523
+ timedOutFuture.future_ ->markCompleted (exceptionMsg);
524
+
525
+ const int dst = timedOutFuture.dstRank_ ;
526
+ recvCounts_.increment (dst);
527
+ futureCV_.notify_all ();
528
+ }
529
+ }
530
+ }
531
+
532
+ const std::vector<ProcessGroupAgent::FutureInfo> ProcessGroupAgent::
533
+ processTimedOutFutures () {
534
+ std::vector<FutureInfo> timedOutFutures;
535
+ for (auto it = futureTimeouts_.begin (); it != futureTimeouts_.end ();
536
+ /* intentional no increment */ ) {
537
+ const auto & endTime = it->first ;
538
+ const auto remainingTime = getRPCRemainingTime (endTime);
539
+
540
+ if (remainingTime.count () > 0 ) {
541
+ // Since the futureTimeouts_ map is ordered by timeout, we don't need
542
+ // to check the remaining futures.
543
+ break ;
544
+ } else {
545
+ const std::vector<int64_t >& futureIDs = it->second ;
546
+ for (const auto & futureID : futureIDs) {
547
+ auto futureIt = futures_.find (futureID);
548
+ TORCH_INTERNAL_ASSERT (
549
+ futureIt != futures_.end (),
550
+ " Race Condition - Expected future does not exist in map" );
551
+ const auto futInfo = futureIt->second ;
552
+ timedOutFutures.push_back (futInfo);
553
+ futures_.erase (futureID);
554
+ }
555
+ it = futureTimeouts_.erase (it);
556
+ }
557
+ }
558
+ return timedOutFutures;
559
+ }
560
+
561
+ const std::chrono::milliseconds ProcessGroupAgent::getRPCRemainingTime (
562
+ const std::chrono::milliseconds& rpcEndTime) const {
563
+ const auto remainingTime =
564
+ rpcEndTime -
565
+ std::chrono::duration_cast<std::chrono::milliseconds>(
566
+ std::chrono::steady_clock::now ().time_since_epoch ());
567
+ return remainingTime;
568
+ }
569
+
570
+ const std::chrono::milliseconds ProcessGroupAgent::getRPCEndTime (
571
+ const FutureInfo& futureInfo) const {
572
+ return futureInfo.timeout_ == INFINITE_TIMEOUT
573
+ ? INFINITE_TIMEOUT
574
+ : futureInfo.startTime_ + futureInfo.timeout_ ;
575
+ }
576
+
452
577
} // namespace rpc
453
578
} // namespace distributed
454
579
} // namespace torch
0 commit comments