Fix Deadlock with transaction recovery is possible during Citus upgrades (citusdata#7875)

Currently, RecoverWorkerTransactions() creates a new connection for each worker
node and then performs transaction recovery by reading and locking the
pg_dist_transaction catalog table until the end of the transaction.
When RecoverTwoPhaseCommits() calls RecoverWorkerTransactions() for each worker
node, the lock acquisition order between pg_dist_authinfo and
pg_dist_transaction can reverse on alternate iterations.
This reversal can lead to a deadlock if any concurrent process requires locks
on these catalog tables—a situation that has surfaced during the
Citus upgrade workflow.

To resolve this, we now establish all worker node connections upfront, before
the first call to RecoverWorkerTransactions().
Every pg_dist_authinfo lookup made by GetNodeConnection() therefore completes
before any lock on pg_dist_transaction is taken, so the two catalog tables are
always locked in the same order: pg_dist_authinfo first, then
pg_dist_transaction. This prevents the deadlock during extension updates or
similar operations.
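
The lock-order argument above can be checked with a small toy model. The sketch
below is not Citus code: LockState, LookupAuthinfo(), RecoverOneWorker() and the
two Count... functions are hypothetical names introduced only for illustration.
It assumes, as the message describes, that the pg_dist_authinfo lock is released
right after the lookup while the pg_dist_transaction lock is held until the end
of the transaction, and it counts how often pg_dist_authinfo would be locked
while pg_dist_transaction is already held.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical model of the locks held by the recovering backend. */
typedef struct LockState
{
	bool holdsTransactionLock;  /* models pg_dist_transaction, held to end of xact */
	int orderViolations;        /* authinfo locked while transaction lock is held */
} LockState;

/* Models the catalog lookup done while opening a connection: lock, read, unlock. */
static void
LookupAuthinfo(LockState *state)
{
	if (state->holdsTransactionLock)
	{
		/* reversed order: pg_dist_transaction first, then pg_dist_authinfo */
		state->orderViolations++;
	}
}

/* Models recovery for one worker: the transaction lock is taken and kept. */
static void
RecoverOneWorker(LockState *state)
{
	state->holdsTransactionLock = true;
}

/* Old flow: connect and recover inside the same loop iteration. */
int
CountOrderViolationsInterleaved(int workerCount)
{
	assert(workerCount >= 0);
	LockState state = { false, 0 };

	for (int i = 0; i < workerCount; i++)
	{
		LookupAuthinfo(&state);
		RecoverOneWorker(&state);
	}
	return state.orderViolations;
}

/* New flow: open every connection first, then run recovery. */
int
CountOrderViolationsPreconnected(int workerCount)
{
	assert(workerCount >= 0);
	LockState state = { false, 0 };

	for (int i = 0; i < workerCount; i++)
	{
		LookupAuthinfo(&state);
	}
	for (int i = 0; i < workerCount; i++)
	{
		RecoverOneWorker(&state);
	}
	return state.orderViolations;
}
```

In this model the interleaved flow reverses the order once for every worker
after the first, while the pre-connected flow never does. A single-worker
cluster shows no reversal in either flow, which matches the note that the
problem needs multiple workers in the active primary node list.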
codeforall committed Feb 24, 2025
1 parent 3df7f26 commit ca8cb7a
Showing 1 changed file with 46 additions and 5 deletions.
51 changes: 46 additions & 5 deletions src/backend/distributed/transaction/transaction_recovery.c
@@ -53,7 +53,8 @@ PG_FUNCTION_INFO_V1(recover_prepared_transactions);


 /* Local functions forward declarations */
-static int RecoverWorkerTransactions(WorkerNode *workerNode);
+static int RecoverWorkerTransactions(WorkerNode *workerNode,
+									 MultiConnection *connection);
 static List * PendingWorkerTransactionList(MultiConnection *connection);
 static bool IsTransactionInProgress(HTAB *activeTransactionNumberSet,
 									char *preparedTransactionName);
@@ -123,10 +124,51 @@ RecoverTwoPhaseCommits(void)
 	LockTransactionRecovery(ShareUpdateExclusiveLock);

 	List *workerList = ActivePrimaryNodeList(NoLock);
+	List *workerConnections = NIL;
 	WorkerNode *workerNode = NULL;
+	MultiConnection *connection = NULL;
+
+	/*
+	 * Pre-establish all connections to worker nodes.
+	 *
+	 * We do this to enforce a consistent lock acquisition order and prevent deadlocks.
+	 * Currently, during extension updates, we take strong locks on the Citus
+	 * catalog tables in a specific order: first on pg_dist_authinfo, then on
+	 * pg_dist_transaction. It's critical that any operation locking these two
+	 * tables adheres to this order, or a deadlock could occur.
+	 *
+	 * Note that RecoverWorkerTransactions() retains its lock until the end
+	 * of the transaction, while GetNodeConnection() releases its lock after
+	 * the catalog lookup. So when there are multiple workers in the active primary
+	 * node list, the lock acquisition order may reverse in subsequent iterations
+	 * of the loop calling RecoverWorkerTransactions(), increasing the risk
+	 * of deadlock.
+	 *
+	 * By establishing all worker connections upfront, we ensure that
+	 * RecoverWorkerTransactions() deals with a single distributed catalog table,
+	 * thereby preventing deadlocks regardless of the lock acquisition sequence
+	 * used in the upgrade extension script.
+	 */
+
 	foreach_declared_ptr(workerNode, workerList)
 	{
-		recoveredTransactionCount += RecoverWorkerTransactions(workerNode);
+		int connectionFlags = 0;
+		char *nodeName = workerNode->workerName;
+		int nodePort = workerNode->workerPort;
+
+		connection = GetNodeConnection(connectionFlags, nodeName, nodePort);
+		Assert(connection != NULL);
+
+		/*
+		 * We don't verify connection validity here.
+		 * Instead, RecoverWorkerTransactions() performs the necessary
+		 * sanity checks on the connection state.
+		 */
+		workerConnections = lappend(workerConnections, connection);
+	}
+	forboth_ptr(workerNode, workerList, connection, workerConnections)
+	{
+		recoveredTransactionCount += RecoverWorkerTransactions(workerNode, connection);
 	}

 	return recoveredTransactionCount;
@@ -138,7 +180,7 @@ RecoverTwoPhaseCommits(void)
  * started by this node on the specified worker.
  */
 static int
-RecoverWorkerTransactions(WorkerNode *workerNode)
+RecoverWorkerTransactions(WorkerNode *workerNode, MultiConnection *connection)
 {
 	int recoveredTransactionCount = 0;

@@ -156,8 +198,7 @@ RecoverWorkerTransactions(WorkerNode *workerNode)

 	bool recoveryFailed = false;

-	int connectionFlags = 0;
-	MultiConnection *connection = GetNodeConnection(connectionFlags, nodeName, nodePort);
+	Assert(connection != NULL);
 	if (connection->pgConn == NULL || PQstatus(connection->pgConn) != CONNECTION_OK)
 	{
 		ereport(WARNING, (errmsg("transaction recovery cannot connect to %s:%d",