diff --git a/src/backend/distributed/transaction/transaction_recovery.c b/src/backend/distributed/transaction/transaction_recovery.c index f25823b3064..4a1503b55d1 100644 --- a/src/backend/distributed/transaction/transaction_recovery.c +++ b/src/backend/distributed/transaction/transaction_recovery.c @@ -53,7 +53,8 @@ PG_FUNCTION_INFO_V1(recover_prepared_transactions); /* Local functions forward declarations */ -static int RecoverWorkerTransactions(WorkerNode *workerNode); +static int RecoverWorkerTransactions(WorkerNode *workerNode, + MultiConnection *connection); static List * PendingWorkerTransactionList(MultiConnection *connection); static bool IsTransactionInProgress(HTAB *activeTransactionNumberSet, char *preparedTransactionName); @@ -123,10 +124,51 @@ RecoverTwoPhaseCommits(void) LockTransactionRecovery(ShareUpdateExclusiveLock); List *workerList = ActivePrimaryNodeList(NoLock); + List *workerConnections = NIL; WorkerNode *workerNode = NULL; + MultiConnection *connection = NULL; + + /* + * Pre-establish all connections to worker nodes. + * + * We do this to enforce a consistent lock acquisition order and prevent deadlocks. + * Currently, during extension updates, we take strong locks on the Citus + * catalog tables in a specific order: first on pg_dist_authinfo, then on + * pg_dist_transaction. It's critical that any operation locking these two + * tables adheres to this order, or a deadlock could occur. + * + * Note that RecoverWorkerTransactions() retains its lock until the end + * of the transaction, while GetNodeConnection() releases its lock after + * the catalog lookup. So when there are multiple workers in the active primary + * node list, the lock acquisition order may reverse in subsequent iterations + * of the loop calling RecoverWorkerTransactions(), increasing the risk + * of deadlock. + * + * By establishing all worker connections upfront, we ensure that + * RecoverWorkerTransactions() deals with a single distributed catalog table, + * thereby preventing deadlocks regardless of the lock acquisition sequence + * used in the upgrade extension script. + */ + foreach_declared_ptr(workerNode, workerList) { - recoveredTransactionCount += RecoverWorkerTransactions(workerNode); + int connectionFlags = 0; + char *nodeName = workerNode->workerName; + int nodePort = workerNode->workerPort; + + connection = GetNodeConnection(connectionFlags, nodeName, nodePort); + Assert(connection != NULL); + + /* + * We don't verify connection validity here. + * Instead, RecoverWorkerTransactions() performs the necessary + * sanity checks on the connection state. + */ + workerConnections = lappend(workerConnections, connection); + } + forboth_ptr(workerNode, workerList, connection, workerConnections) + { + recoveredTransactionCount += RecoverWorkerTransactions(workerNode, connection); } return recoveredTransactionCount; @@ -138,7 +180,7 @@ RecoverTwoPhaseCommits(void) * started by this node on the specified worker. */ static int -RecoverWorkerTransactions(WorkerNode *workerNode) +RecoverWorkerTransactions(WorkerNode *workerNode, MultiConnection *connection) { int recoveredTransactionCount = 0; @@ -156,8 +198,7 @@ RecoverWorkerTransactions(WorkerNode *workerNode) bool recoveryFailed = false; - int connectionFlags = 0; - MultiConnection *connection = GetNodeConnection(connectionFlags, nodeName, nodePort); + Assert(connection != NULL); if (connection->pgConn == NULL || PQstatus(connection->pgConn) != CONNECTION_OK) { ereport(WARNING, (errmsg("transaction recovery cannot connect to %s:%d",