From b93360180ded35144ec73df49af75e8e53666de0 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 3 Apr 2025 15:30:01 -0700 Subject: [PATCH 1/3] bulk_write should be able to accept a generator --- pymongo/asynchronous/bulk.py | 37 ++++++++++++++++++++++-------- pymongo/asynchronous/collection.py | 12 ++++------ pymongo/common.py | 8 +++++++ pymongo/synchronous/bulk.py | 37 ++++++++++++++++++++++-------- pymongo/synchronous/collection.py | 12 ++++------ test/asynchronous/test_bulk.py | 15 ++++++++++++ test/test_bulk.py | 15 ++++++++++++ 7 files changed, 100 insertions(+), 36 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index ac514db98f..b4b042a632 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -26,6 +26,7 @@ from typing import ( TYPE_CHECKING, Any, + Generator, Iterator, Mapping, Optional, @@ -72,7 +73,7 @@ from pymongo.write_concern import WriteConcern if TYPE_CHECKING: - from pymongo.asynchronous.collection import AsyncCollection + from pymongo.asynchronous.collection import AsyncCollection, _WriteOp from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.pool import AsyncConnection from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline @@ -214,28 +215,45 @@ def add_delete( self.is_retryable = False self.ops.append((_DELETE, cmd)) - def gen_ordered(self) -> Iterator[Optional[_Run]]: + def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: """Generate batches of operations, batched by type of operation, in the order **provided**. """ run = None - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + try: + request._add_to_bulk(self) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) elif run.op_type != op_type: yield run run = _Run(op_type) run.add(idx, operation) + if run is None: + raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self) -> Iterator[_Run]: + def gen_unordered(self, requests) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. 
""" operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + try: + request._add_to_bulk(self) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - + if ( + len(operations[_INSERT].ops) == 0 + and len(operations[_UPDATE].ops) == 0 + and len(operations[_DELETE].ops) == 0 + ): + raise InvalidOperation("No operations to execute") for run in operations: if run.ops: yield run @@ -726,13 +744,12 @@ async def execute_no_results( async def execute( self, + generator: Generator[_WriteOp[_DocumentType]], write_concern: WriteConcern, session: Optional[AsyncClientSession], operation: str, ) -> Any: """Execute operations.""" - if not self.ops: - raise InvalidOperation("No operations to execute") if self.executed: raise InvalidOperation("Bulk operations can only be executed once.") self.executed = True @@ -740,9 +757,9 @@ async def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered() + generator = self.gen_ordered(generator) else: - generator = self.gen_unordered() + generator = self.gen_unordered(generator) client = self.collection.database.client if not write_concern.acknowledged: diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 7fb20b7ab3..3286dea4d4 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -23,6 +23,7 @@ AsyncContextManager, Callable, Coroutine, + Generator, Generic, Iterable, Iterator, @@ -699,7 +700,7 @@ async def _create( @_csot.apply async def bulk_write( self, - requests: Sequence[_WriteOp[_DocumentType]], + requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]], ordered: bool = True, bypass_document_validation: Optional[bool] = None, session: Optional[AsyncClientSession] = None, @@ -779,17 +780,12 @@ async def bulk_write( .. 
versionadded:: 3.0 """ - common.validate_list("requests", requests) + common.validate_list_or_generator("requests", requests) blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let) - for request in requests: - try: - request._add_to_bulk(blk) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None write_concern = self._write_concern_for(session) - bulk_api_result = await blk.execute(write_concern, session, _Op.INSERT) + bulk_api_result = await blk.execute(requests, write_concern, session, _Op.INSERT) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) diff --git a/pymongo/common.py b/pymongo/common.py index 3d8095eedf..6d9bb2f37a 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -24,6 +24,7 @@ TYPE_CHECKING, Any, Callable, + Generator, Iterator, Mapping, MutableMapping, @@ -530,6 +531,13 @@ def validate_list(option: str, value: Any) -> list: return value +def validate_list_or_generator(option: str, value: Any) -> Union[list, Generator]: + """Validates that 'value' is a list or generator.""" + if isinstance(value, Generator): + return value + return validate_list(option, value) + + def validate_list_or_none(option: Any, value: Any) -> Optional[list]: """Validates that 'value' is a list or None.""" if value is None: diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index a528b09add..b92bb1a511 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -26,6 +26,7 @@ from typing import ( TYPE_CHECKING, Any, + Generator, Iterator, Mapping, Optional, @@ -72,7 +73,7 @@ from pymongo.write_concern import WriteConcern if TYPE_CHECKING: - from pymongo.synchronous.collection import Collection + from pymongo.synchronous.collection import Collection, _WriteOp from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline @@ -214,28 +215,45 @@ def add_delete( self.is_retryable = False self.ops.append((_DELETE, cmd)) - def gen_ordered(self) -> Iterator[Optional[_Run]]: + def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: """Generate batches of operations, batched by type of operation, in the order **provided**. """ run = None - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + try: + request._add_to_bulk(self) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) elif run.op_type != op_type: yield run run = _Run(op_type) run.add(idx, operation) + if run is None: + raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self) -> Iterator[_Run]: + def gen_unordered(self, requests) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. 
""" operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + try: + request._add_to_bulk(self) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - + if ( + len(operations[_INSERT].ops) == 0 + and len(operations[_UPDATE].ops) == 0 + and len(operations[_DELETE].ops) == 0 + ): + raise InvalidOperation("No operations to execute") for run in operations: if run.ops: yield run @@ -724,13 +742,12 @@ def execute_no_results( def execute( self, + generator: Generator[_WriteOp[_DocumentType]], write_concern: WriteConcern, session: Optional[ClientSession], operation: str, ) -> Any: """Execute operations.""" - if not self.ops: - raise InvalidOperation("No operations to execute") if self.executed: raise InvalidOperation("Bulk operations can only be executed once.") self.executed = True @@ -738,9 +755,9 @@ def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered() + generator = self.gen_ordered(generator) else: - generator = self.gen_unordered() + generator = self.gen_unordered(generator) client = self.collection.database.client if not write_concern.acknowledged: diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 8a71768318..52c25de744 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -22,6 +22,7 @@ Any, Callable, ContextManager, + Generator, Generic, Iterable, Iterator, @@ -698,7 +699,7 @@ def _create( @_csot.apply def bulk_write( self, - requests: Sequence[_WriteOp[_DocumentType]], + requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]], ordered: bool = True, bypass_document_validation: Optional[bool] = None, session: Optional[ClientSession] = None, @@ -778,17 +779,12 @@ def bulk_write( .. versionadded:: 3.0 """ - common.validate_list("requests", requests) + common.validate_list_or_generator("requests", requests) blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let) - for request in requests: - try: - request._add_to_bulk(blk) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None write_concern = self._write_concern_for(session) - bulk_api_result = blk.execute(write_concern, session, _Op.INSERT) + bulk_api_result = blk.execute(requests, write_concern, session, _Op.INSERT) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 65ed6e236a..3becea0777 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -299,6 +299,21 @@ async def test_numerous_inserts(self): self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, await self.coll.count_documents({})) + async def test_numerous_inserts_generator(self): + # Ensure we don't exceed server's maxWriteBatchSize size limit. + n_docs = await async_client_context.max_write_batch_size + 100 + requests = (InsertOne[dict]({}) for _ in range(n_docs)) + result = await self.coll.bulk_write(requests, ordered=False) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, await self.coll.count_documents({})) + + # Same with ordered bulk. 
+ await self.coll.drop() + requests = (InsertOne[dict]({}) for _ in range(n_docs)) + result = await self.coll.bulk_write(requests) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, await self.coll.count_documents({})) + async def test_bulk_max_message_size(self): await self.coll.delete_many({}) self.addAsyncCleanup(self.coll.delete_many, {}) diff --git a/test/test_bulk.py b/test/test_bulk.py index 8a863cc49b..3e631e661f 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -299,6 +299,21 @@ def test_numerous_inserts(self): self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, self.coll.count_documents({})) + def test_numerous_inserts_generator(self): + # Ensure we don't exceed server's maxWriteBatchSize size limit. + n_docs = client_context.max_write_batch_size + 100 + requests = (InsertOne[dict]({}) for _ in range(n_docs)) + result = self.coll.bulk_write(requests, ordered=False) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, self.coll.count_documents({})) + + # Same with ordered bulk. + self.coll.drop() + requests = (InsertOne[dict]({}) for _ in range(n_docs)) + result = self.coll.bulk_write(requests) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, self.coll.count_documents({})) + def test_bulk_max_message_size(self): self.coll.delete_many({}) self.addCleanup(self.coll.delete_many, {}) From 0648bcfafc0c81d11af5df7d17c2d213fb733e43 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 21 Apr 2025 15:00:48 -0700 Subject: [PATCH 2/3] wip --- pymongo/asynchronous/bulk.py | 120 +++++++++++++++--------- pymongo/asynchronous/client_bulk.py | 13 +-- pymongo/asynchronous/client_session.py | 7 +- pymongo/asynchronous/collection.py | 58 +++++------- pymongo/asynchronous/mongo_client.py | 61 ++++++++----- pymongo/bulk_shared.py | 3 + pymongo/operations.py | 24 ++--- pymongo/synchronous/bulk.py | 122 ++++++++++++++++--------- pymongo/synchronous/client_bulk.py | 13 +-- pymongo/synchronous/client_session.py | 5 +- pymongo/synchronous/collection.py | 64 +++++-------- pymongo/synchronous/mongo_client.py | 59 +++++++----- test/asynchronous/test_bulk.py | 5 - test/test_bulk.py | 5 - 14 files changed, 307 insertions(+), 252 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index b4b042a632..630be7c25e 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -26,7 +26,8 @@ from typing import ( TYPE_CHECKING, Any, - Generator, + Callable, + Iterable, Iterator, Mapping, Optional, @@ -111,9 +112,6 @@ def __init__( self.uses_hint_update = False self.uses_hint_delete = False self.uses_sort = False - self.is_retryable = True - self.retrying = False - self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. 
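[Editor's note] Taken together, the series makes `requests` lazily consumable, which is what the tests above exercise. A minimal usage sketch of the new surface (assumes a reachable local mongod and a throwaway `test.docs` collection; not part of the patch):

    from pymongo import InsertOne, MongoClient

    client = MongoClient()  # assumption: local mongod on the default port
    coll = client.test.docs

    # A generator is accepted anywhere a list of write models was; requests
    # are consumed batch by batch instead of being materialized up front.
    requests = (InsertOne({"seq": i}) for i in range(100_000))
    result = coll.bulk_write(requests, ordered=False)
    print(result.inserted_count)
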
self.current_run = None self.next_run = None @@ -129,13 +127,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - def add_insert(self, document: _DocumentOut) -> None: + @property + def is_retryable(self) -> bool: + if self.current_run: + return self.current_run.is_retryable + return True + + @property + def retrying(self) -> bool: + if self.current_run: + return self.current_run.retrying + return False + + @property + def started_retryable_write(self) -> bool: + if self.current_run: + return self.current_run.started_retryable_write + return False + + def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" validate_is_document_type("document", document) # Generate ObjectId client side. if not (isinstance(document, RawBSONDocument) or "_id" in document): document["_id"] = ObjectId() self.ops.append((_INSERT, document)) + return True def add_update( self, @@ -147,7 +164,7 @@ def add_update( array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi} @@ -165,10 +182,12 @@ def add_update( if sort is not None: self.uses_sort = True cmd["sort"] = sort + + self.ops.append((_UPDATE, cmd)) if multi: # A bulk_write containing an update_many is not retryable. - self.is_retryable = False - self.ops.append((_UPDATE, cmd)) + return False + return True def add_replace( self, @@ -178,7 +197,7 @@ def add_replace( collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) cmd: dict[str, Any] = {"q": selector, "u": replacement} @@ -194,6 +213,7 @@ def add_replace( self.uses_sort = True cmd["sort"] = sort self.ops.append((_UPDATE, cmd)) + return True def add_delete( self, @@ -201,7 +221,7 @@ def add_delete( limit: int, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, - ) -> None: + ) -> bool: """Create a delete document and add it to the list of ops.""" cmd: dict[str, Any] = {"q": selector, "limit": limit} if collation is not None: @@ -210,21 +230,24 @@ def add_delete( if hint is not None: self.uses_hint_delete = True cmd["hint"] = hint + + self.ops.append((_DELETE, cmd)) if limit == _DELETE_ALL: # A bulk_write containing a delete_many is not retryable. - self.is_retryable = False - self.ops.append((_DELETE, cmd)) + return False + return True - def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: + def gen_ordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in the order **provided**. 
""" run = None for idx, request in enumerate(requests): - try: - request._add_to_bulk(self) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None + retryable = process(request) (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) @@ -232,22 +255,25 @@ def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: yield run run = _Run(op_type) run.add(idx, operation) + run.is_retryable = run.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self, requests) -> Iterator[_Run]: + def gen_unordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] for idx, request in enumerate(requests): - try: - request._add_to_bulk(self) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None + retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) + operations[op_type].is_retryable = operations[op_type].is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -488,8 +514,8 @@ async def _execute_command( session: Optional[AsyncClientSession], conn: AsyncConnection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], + validate: bool, final_write_concern: Optional[WriteConcern] = None, ) -> None: db_name = self.collection.database.name @@ -507,7 +533,7 @@ async def _execute_command( last_run = False while run: - if not self.retrying: + if not run.retrying: self.next_run = next(generator, None) if self.next_run is None: last_run = True @@ -541,10 +567,10 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if run.is_retryable and not run.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -552,9 +578,10 @@ async def _execute_command( ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. + if validate: + await self.validate_batch(conn, write_concern) if write_concern.acknowledged: result, to_send = await self._execute_batch(bwc, cmd, ops, client) - # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -567,8 +594,8 @@ async def _execute_command( _merge_command(run, full_result, run.idx_offset, result) # We're no longer in a retry once a command succeeds. 
- self.retrying = False - self.started_retryable_write = False + run.retrying = False + run.started_retryable_write = False if self.ordered and "writeErrors" in result: break @@ -606,7 +633,8 @@ async def execute_command( op_id = _randint() async def retryable_bulk( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool + session: Optional[AsyncClientSession], + conn: AsyncConnection, ) -> None: await self._execute_command( generator, @@ -614,26 +642,24 @@ async def retryable_bulk( session, conn, op_id, - retryable, full_result, + validate=False, ) client = self.collection.database.client _ = await client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, bulk=self, # type: ignore[arg-type] operation_id=op_id, ) - if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result async def execute_op_msg_no_results( - self, conn: AsyncConnection, generator: Iterator[Any] + self, conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern ) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name @@ -667,6 +693,7 @@ async def execute_op_msg_no_results( conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. + await self.validate_batch(conn, write_concern) to_send = await self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) @@ -700,12 +727,15 @@ async def execute_command_no_results( None, conn, op_id, - False, full_result, + True, write_concern, ) - except OperationFailure: - pass + except OperationFailure as exc: + if "Cannot set bypass_document_validation with unacknowledged write concern" in str( + exc + ): + raise exc async def execute_no_results( self, @@ -714,6 +744,11 @@ async def execute_no_results( write_concern: WriteConcern, ) -> None: """Execute all operations, returning no results (w=0).""" + if self.ordered: + return await self.execute_command_no_results(conn, generator, write_concern) + return await self.execute_op_msg_no_results(conn, generator, write_concern) + + async def validate_batch(self, conn: AsyncConnection, write_concern: WriteConcern) -> None: if self.uses_collation: raise ConfigurationError("Collation is unsupported for unacknowledged writes.") if self.uses_array_filters: @@ -738,13 +773,10 @@ async def execute_no_results( "Cannot set bypass_document_validation with unacknowledged write concern" ) - if self.ordered: - return await self.execute_command_no_results(conn, generator, write_concern) - return await self.execute_op_msg_no_results(conn, generator) - async def execute( self, - generator: Generator[_WriteOp[_DocumentType]], + generator: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], write_concern: WriteConcern, session: Optional[AsyncClientSession], operation: str, @@ -757,9 +789,9 @@ async def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered(generator) + generator = self.gen_ordered(generator, process) else: - generator = self.gen_unordered(generator) + generator = self.gen_unordered(generator, process) client = self.collection.database.client if not write_concern.acknowledged: diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 5f7ac013e9..dbbad9e0e8 100644 --- 
a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -116,6 +116,7 @@ def __init__( self.is_retryable = self.client.options.retry_writes self.retrying = False self.started_retryable_write = False + self.current_run = None @property def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]: @@ -488,7 +489,6 @@ async def _execute_command( session: Optional[AsyncClientSession], conn: AsyncConnection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], final_write_concern: Optional[WriteConcern] = None, ) -> None: @@ -534,10 +534,10 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, self.client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -564,7 +564,7 @@ async def _execute_command( # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. - if retryable and (retryable_top_level_error or retryable_network_error): + if self.is_retryable and (retryable_top_level_error or retryable_network_error): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) @@ -583,7 +583,7 @@ async def _execute_command( _merge_command(self.ops, self.idx_offset, full_result, result) break - if retryable: + if self.is_retryable: # Retryable writeConcernErrors halt the execution of this batch. 
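[Editor's note] With the explicit `retryable` argument gone, the client-level path decides retryability once, from the client options (`self.is_retryable = self.client.options.retry_writes` above). A sketch of the observable contract (assumed URI; `MongoClient.bulk_write` also needs a server new enough to pass the `max_wire_version < 25` check below):

    from pymongo import InsertOne, MongoClient

    # retryWrites=false makes _ClientBulk.is_retryable False from construction,
    # so no retryable-write txnNumber is attached to the bulkWrite command.
    client = MongoClient("mongodb://localhost:27017/?retryWrites=false")
    client.bulk_write([InsertOne({"x": 1}, namespace="test.docs")])
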
wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -638,7 +638,6 @@ async def execute_command( async def retryable_bulk( session: Optional[AsyncClientSession], conn: AsyncConnection, - retryable: bool, ) -> None: if conn.max_wire_version < 25: raise InvalidOperation( @@ -649,12 +648,10 @@ async def retryable_bulk( session, conn, op_id, - retryable, full_result, ) await self.client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index b808684dd4..b9d8449a34 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -854,13 +854,12 @@ async def _finish_transaction_with_retry(self, command_name: str) -> dict[str, A """ async def func( - _session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool + _session: Optional[AsyncClientSession], + conn: AsyncConnection, ) -> dict[str, Any]: return await self._finish_transaction(conn, command_name) - return await self._client._retry_internal( - func, self, None, retryable=True, operation=_Op.ABORT - ) + return await self._client._retry_internal(func, self, None, operation=_Op.ABORT) async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 3286dea4d4..5ee67ddf89 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -23,7 +23,6 @@ AsyncContextManager, Callable, Coroutine, - Generator, Generic, Iterable, Iterator, @@ -700,7 +699,7 @@ async def _create( @_csot.apply async def bulk_write( self, - requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]], + requests: Iterable[_WriteOp], ordered: bool = True, bypass_document_validation: Optional[bool] = None, session: Optional[AsyncClientSession] = None, @@ -785,7 +784,16 @@ async def bulk_write( blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let) write_concern = self._write_concern_for(session) - bulk_api_result = await blk.execute(requests, write_concern, session, _Op.INSERT) + + def process_for_bulk(request: _WriteOp) -> bool: + try: + return request._add_to_bulk(blk) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + + bulk_api_result = await blk.execute( + requests, process_for_bulk, write_concern, session, _Op.INSERT + ) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) @@ -802,17 +810,15 @@ async def _insert_one( ) -> Any: """Internal helper for inserting a single document.""" write_concern = write_concern or self.write_concern - acknowledged = write_concern.acknowledged command = {"insert": self.name, "ordered": ordered, "documents": [doc]} if comment is not None: command["comment"] = comment async def _insert_command( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection ) -> None: if bypass_doc_val is not None: command["bypassDocumentValidation"] = bypass_doc_val - result = await conn.command( self._database.name, command, @@ -820,14 +826,11 @@ async def _insert_command( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) 
_check_write_command_response(result) - await self._database.client._retryable_write( - acknowledged, _insert_command, session, operation=_Op.INSERT - ) + await self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT) if not isinstance(doc, RawBSONDocument): return doc.get("_id") @@ -956,20 +959,19 @@ async def insert_many( raise TypeError("documents must be a non-empty list") inserted_ids: list[ObjectId] = [] - def gen() -> Iterator[tuple[int, Mapping[str, Any]]]: + def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool: """A generator that validates documents and handles _ids.""" - for document in documents: - common.validate_is_document_type("document", document) - if not isinstance(document, RawBSONDocument): - if "_id" not in document: - document["_id"] = ObjectId() # type: ignore[index] - inserted_ids.append(document["_id"]) - yield (message._INSERT, document) + common.validate_is_document_type("document", document) + if not isinstance(document, RawBSONDocument): + if "_id" not in document: + document["_id"] = ObjectId() # type: ignore[index] + inserted_ids.append(document["_id"]) + blk.ops.append((message._INSERT, document)) + return True write_concern = self._write_concern_for(session) blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment) - blk.ops = list(gen()) - await blk.execute(write_concern, session, _Op.INSERT) + await blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT) return InsertManyResult(inserted_ids, write_concern.acknowledged) async def _update( @@ -987,7 +989,6 @@ async def _update( array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, - retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -1050,7 +1051,6 @@ async def _update( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) ).copy() _check_write_command_response(result) @@ -1090,7 +1090,7 @@ async def _update_retryable( """Internal update / replace helper.""" async def _update( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection ) -> Optional[Mapping[str, Any]]: return await self._update( conn, @@ -1106,14 +1106,12 @@ async def _update( array_filters=array_filters, hint=hint, session=session, - retryable_write=retryable_write, let=let, sort=sort, comment=comment, ) return await self._database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, _update, session, operation, @@ -1503,7 +1501,6 @@ async def _delete( collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, - retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Mapping[str, Any]: @@ -1543,7 +1540,6 @@ async def _delete( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) _check_write_command_response(result) return result @@ -1564,7 +1560,7 @@ async def _delete_retryable( """Internal delete helper.""" async def _delete( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool + session: Optional[AsyncClientSession], conn: 
AsyncConnection ) -> Mapping[str, Any]: return await self._delete( conn, @@ -1576,13 +1572,11 @@ async def _delete( collation=collation, hint=hint, session=session, - retryable_write=retryable_write, let=let, comment=comment, ) return await self._database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, _delete, session, operation=_Op.DELETE, @@ -3227,7 +3221,7 @@ async def _find_and_modify( write_concern = self._write_concern_for_cmd(cmd, session) async def _find_and_modify_helper( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection ) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: @@ -3253,7 +3247,6 @@ async def _find_and_modify_helper( write_concern=write_concern, collation=collation, session=session, - retryable_write=retryable_write, user_fields=_FIND_AND_MODIFY_DOC_FIELDS, ) _check_write_command_response(out) @@ -3261,7 +3254,6 @@ async def _find_and_modify_helper( return out.get("value") return await self._database.client._retryable_write( - write_concern.acknowledged, _find_and_modify_helper, session, operation=_Op.FIND_AND_MODIFY, diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 16753420c0..4c8230cea9 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -142,9 +142,7 @@ T = TypeVar("T") -_WriteCall = Callable[ - [Optional["AsyncClientSession"], "AsyncConnection", bool], Coroutine[Any, Any, T] -] +_WriteCall = Callable[[Optional["AsyncClientSession"], "AsyncConnection"], Coroutine[Any, Any, T]] _ReadCall = Callable[ [Optional["AsyncClientSession"], "Server", "AsyncConnection", _ServerMode], Coroutine[Any, Any, T], @@ -1894,7 +1892,6 @@ async def _cmd( async def _retry_with_session( self, - retryable: bool, func: _WriteCall[T], session: Optional[AsyncClientSession], bulk: Optional[Union[_AsyncBulk, _AsyncClientBulk]], @@ -1910,15 +1907,11 @@ async def _retry_with_session( """ # Ensure that the options supports retry_writes and there is a valid session not in # transaction, otherwise, we will not support retry behavior for this txn. - retryable = bool( - retryable and self.options.retry_writes and session and not session.in_transaction - ) return await self._retry_internal( func=func, session=session, bulk=bulk, operation=operation, - retryable=retryable, operation_id=operation_id, ) @@ -1932,7 +1925,6 @@ async def _retry_internal( is_read: bool = False, address: Optional[_Address] = None, read_pref: Optional[_ServerMode] = None, - retryable: bool = False, operation_id: Optional[int] = None, ) -> T: """Internal retryable helper for all client transactions. 
@@ -1957,7 +1949,6 @@ async def _retry_internal( session=session, read_pref=read_pref, address=address, - retryable=retryable, operation_id=operation_id, ).run() @@ -2000,13 +1991,11 @@ async def _retryable_read( is_read=True, address=address, read_pref=read_pref, - retryable=retryable, operation_id=operation_id, ) async def _retryable_write( self, - retryable: bool, func: _WriteCall[T], session: Optional[AsyncClientSession], operation: str, @@ -2027,7 +2016,7 @@ async def _retryable_write( :param bulk: bulk abstraction to execute operations in bulk, defaults to None """ async with self._tmp_session(session) as s: - return await self._retry_with_session(retryable, func, s, bulk, operation, operation_id) + return await self._retry_with_session(func, s, bulk, operation, operation_id) def _cleanup_cursor_no_lock( self, @@ -2662,7 +2651,6 @@ def __init__( session: Optional[AsyncClientSession] = None, read_pref: Optional[_ServerMode] = None, address: Optional[_Address] = None, - retryable: bool = False, operation_id: Optional[int] = None, ): self._last_error: Optional[Exception] = None @@ -2674,7 +2662,7 @@ def __init__( self._bulk = bulk self._session = session self._is_read = is_read - self._retryable = retryable + self._retryable = True self._read_pref = read_pref self._server_selector: Callable[[Selection], Selection] = ( read_pref if is_read else writable_server_selector # type: ignore @@ -2685,6 +2673,11 @@ def __init__( self._operation = operation self._operation_id = operation_id + def _bulk_retryable(self) -> bool: + if self._bulk is not None and self._bulk.current_run is not None: + return self._bulk.current_run.is_retryable + return True + async def run(self) -> T: """Runs the supplied func() and attempts a retry @@ -2695,10 +2688,15 @@ async def run(self) -> T: # Increment the transaction id up front to ensure any retry attempt # will use the proper txnNumber, even if server or socket selection # fails before the command can be sent. 
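[Editor's note] The bool that `_add_to_bulk` now returns feeds these retry checks: UpdateMany and DeleteMany report False, everything else True, and a run stays retryable only if every op in it reported True. Schematically (plain bools, not PyMongo types):

    # One multi-document op poisons the whole run, mirroring
    # run.is_retryable = run.is_retryable and retryable in gen_ordered.
    flags = [True, True, False, True]  # e.g. insert, insert, update_many, delete_one
    run_is_retryable = True
    for retryable in flags:
        run_is_retryable = run_is_retryable and retryable
    assert run_is_retryable is False
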
- if self._is_session_state_retryable() and self._retryable and not self._is_read: + if ( + self._is_session_state_retryable() + and self._retryable + and self._bulk_retryable() + and not self._is_read + ): self._session._start_retryable_write() # type: ignore - if self._bulk: - self._bulk.started_retryable_write = True + if self._bulk and self._bulk.current_run: + self._bulk.current_run.started_retryable_write = True while True: self._check_last_error(check_csot=True) @@ -2731,7 +2729,7 @@ async def run(self) -> T: # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if not self._retryable and not self._bulk_retryable(): raise if isinstance(exc, ClientBulkWriteException) and exc.error: retryable_write_error_exc = isinstance( @@ -2748,7 +2746,10 @@ async def run(self) -> T: else: raise if self._bulk: - self._bulk.retrying = True + if self._bulk.current_run: + self._bulk.current_run.retrying = True + else: + self._bulk.retrying = True else: self._retrying = True if not exc.has_error_label("NoWritesPerformed"): @@ -2761,11 +2762,19 @@ async def run(self) -> T: def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" - return not self._retryable or (self._is_retrying() and not self._multiple_retries) + return ( + not self._retryable + or not self._bulk_retryable() + or (self._is_retrying() and not self._multiple_retries) + ) def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return self._bulk.retrying if self._bulk else self._retrying + return ( + self._bulk.current_run.retrying + if self._bulk is not None and self._bulk.current_run is not None + else self._retrying + ) def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2825,9 +2834,11 @@ async def _write(self) -> T: # not support sessions raise the last error. self._check_last_error() self._retryable = False - return await self._func(self._session, conn, self._retryable) # type: ignore + if self._bulk and self._bulk.current_run: + self._bulk.current_run.is_retryable = False + return await self._func(self._session, conn) # type: ignore except PyMongoError as exc: - if not self._retryable: + if not self._retryable or not self._bulk_retryable(): raise # Add the RetryableWriteError label, if applicable. _add_retryable_write_error(exc, max_wire_version, is_mongos) @@ -2844,7 +2855,7 @@ async def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and (not self._retryable or not self._bulk_retryable()): self._check_last_error() return await self._func(self._session, self._server, conn, read_pref) # type: ignore diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py index 9276419d8a..b157edd2e2 100644 --- a/pymongo/bulk_shared.py +++ b/pymongo/bulk_shared.py @@ -50,6 +50,9 @@ def __init__(self, op_type: int) -> None: self.index_map: list[int] = [] self.ops: list[Any] = [] self.idx_offset: int = 0 + self.is_retryable = True + self.retrying = False + self.started_retryable_write = False def index(self, idx: int) -> int: """Get the original index of an operation in this run. 
diff --git a/pymongo/operations.py b/pymongo/operations.py index 300f1ba123..49b41ee614 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -106,9 +106,9 @@ def __init__(self, document: _DocumentType, namespace: Optional[str] = None) -> self._doc = document self._namespace = namespace - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_insert(self._doc) # type: ignore[arg-type] + return bulkobj.add_insert(self._doc) # type: ignore[arg-type] def _add_to_client_bulk(self, bulkobj: _AgnosticClientBulk) -> None: """Add this operation to the _AsyncClientBulk/_ClientBulk instance `bulkobj`.""" @@ -230,9 +230,9 @@ def __init__( """ super().__init__(filter, collation, hint, namespace) - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_delete( + return bulkobj.add_delete( self._filter, 1, collation=validate_collation_or_none(self._collation), @@ -291,9 +291,9 @@ def __init__( """ super().__init__(filter, collation, hint, namespace) - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_delete( + return bulkobj.add_delete( self._filter, 0, collation=validate_collation_or_none(self._collation), @@ -384,9 +384,9 @@ def __init__( self._collation = collation self._namespace = namespace - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_replace( + return bulkobj.add_replace( self._filter, self._doc, self._upsert, @@ -606,9 +606,9 @@ def __init__( """ super().__init__(filter, update, upsert, collation, array_filters, hint, namespace, sort) - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_update( + return bulkobj.add_update( self._filter, self._doc, False, @@ -687,9 +687,9 @@ def __init__( """ super().__init__(filter, update, upsert, collation, array_filters, hint, namespace, None) - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_update( + return bulkobj.add_update( self._filter, self._doc, True, diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index b92bb1a511..2734d8d3fc 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -26,7 +26,8 @@ from typing import ( TYPE_CHECKING, Any, - Generator, + Callable, + Iterable, Iterator, Mapping, Optional, @@ -111,9 +112,6 @@ def __init__( self.uses_hint_update = False self.uses_hint_delete = False self.uses_sort = False - self.is_retryable = True - self.retrying = False - self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. 
self.current_run = None self.next_run = None @@ -129,13 +127,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - def add_insert(self, document: _DocumentOut) -> None: + @property + def is_retryable(self) -> bool: + if self.current_run: + return self.current_run.is_retryable + return True + + @property + def retrying(self) -> bool: + if self.current_run: + return self.current_run.retrying + return False + + @property + def started_retryable_write(self) -> bool: + if self.current_run: + return self.current_run.started_retryable_write + return False + + def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" validate_is_document_type("document", document) # Generate ObjectId client side. if not (isinstance(document, RawBSONDocument) or "_id" in document): document["_id"] = ObjectId() self.ops.append((_INSERT, document)) + return True def add_update( self, @@ -147,7 +164,7 @@ def add_update( array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi} @@ -165,10 +182,12 @@ def add_update( if sort is not None: self.uses_sort = True cmd["sort"] = sort + + self.ops.append((_UPDATE, cmd)) if multi: # A bulk_write containing an update_many is not retryable. - self.is_retryable = False - self.ops.append((_UPDATE, cmd)) + return False + return True def add_replace( self, @@ -178,7 +197,7 @@ def add_replace( collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) cmd: dict[str, Any] = {"q": selector, "u": replacement} @@ -194,6 +213,7 @@ def add_replace( self.uses_sort = True cmd["sort"] = sort self.ops.append((_UPDATE, cmd)) + return True def add_delete( self, @@ -201,7 +221,7 @@ def add_delete( limit: int, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, - ) -> None: + ) -> bool: """Create a delete document and add it to the list of ops.""" cmd: dict[str, Any] = {"q": selector, "limit": limit} if collation is not None: @@ -210,21 +230,24 @@ def add_delete( if hint is not None: self.uses_hint_delete = True cmd["hint"] = hint + + self.ops.append((_DELETE, cmd)) if limit == _DELETE_ALL: # A bulk_write containing a delete_many is not retryable. - self.is_retryable = False - self.ops.append((_DELETE, cmd)) + return False + return True - def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: + def gen_ordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in the order **provided**. 
""" run = None for idx, request in enumerate(requests): - try: - request._add_to_bulk(self) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None + retryable = process(request) (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) @@ -232,22 +255,25 @@ def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: yield run run = _Run(op_type) run.add(idx, operation) + run.is_retryable = run.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self, requests) -> Iterator[_Run]: + def gen_unordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] for idx, request in enumerate(requests): - try: - request._add_to_bulk(self) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None + retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) + operations[op_type].is_retryable = operations[op_type].is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -488,8 +514,8 @@ def _execute_command( session: Optional[ClientSession], conn: Connection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], + validate: bool, final_write_concern: Optional[WriteConcern] = None, ) -> None: db_name = self.collection.database.name @@ -507,7 +533,7 @@ def _execute_command( last_run = False while run: - if not self.retrying: + if not run.retrying: self.next_run = next(generator, None) if self.next_run is None: last_run = True @@ -541,10 +567,10 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if run.is_retryable and not run.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -552,9 +578,10 @@ def _execute_command( ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. + if validate: + self.validate_batch(conn, write_concern) if write_concern.acknowledged: result, to_send = self._execute_batch(bwc, cmd, ops, client) - # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -567,8 +594,8 @@ def _execute_command( _merge_command(run, full_result, run.idx_offset, result) # We're no longer in a retry once a command succeeds. 
- self.retrying = False - self.started_retryable_write = False + run.retrying = False + run.started_retryable_write = False if self.ordered and "writeErrors" in result: break @@ -606,7 +633,8 @@ def execute_command( op_id = _randint() def retryable_bulk( - session: Optional[ClientSession], conn: Connection, retryable: bool + session: Optional[ClientSession], + conn: Connection, ) -> None: self._execute_command( generator, @@ -614,25 +642,25 @@ def retryable_bulk( session, conn, op_id, - retryable, full_result, + validate=False, ) client = self.collection.database.client _ = client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, bulk=self, # type: ignore[arg-type] operation_id=op_id, ) - if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result - def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None: + def execute_op_msg_no_results( + self, conn: Connection, generator: Iterator[Any], write_concern: WriteConcern + ) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client @@ -665,6 +693,7 @@ def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. + self.validate_batch(conn, write_concern) to_send = self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) @@ -698,12 +727,15 @@ def execute_command_no_results( None, conn, op_id, - False, full_result, + True, write_concern, ) - except OperationFailure: - pass + except OperationFailure as exc: + if "Cannot set bypass_document_validation with unacknowledged write concern" in str( + exc + ): + raise exc def execute_no_results( self, @@ -712,6 +744,11 @@ def execute_no_results( write_concern: WriteConcern, ) -> None: """Execute all operations, returning no results (w=0).""" + if self.ordered: + return self.execute_command_no_results(conn, generator, write_concern) + return self.execute_op_msg_no_results(conn, generator, write_concern) + + def validate_batch(self, conn: Connection, write_concern: WriteConcern) -> None: if self.uses_collation: raise ConfigurationError("Collation is unsupported for unacknowledged writes.") if self.uses_array_filters: @@ -736,13 +773,10 @@ def execute_no_results( "Cannot set bypass_document_validation with unacknowledged write concern" ) - if self.ordered: - return self.execute_command_no_results(conn, generator, write_concern) - return self.execute_op_msg_no_results(conn, generator) - def execute( self, - generator: Generator[_WriteOp[_DocumentType]], + generator: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], write_concern: WriteConcern, session: Optional[ClientSession], operation: str, @@ -755,9 +789,9 @@ def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered(generator) + generator = self.gen_ordered(generator, process) else: - generator = self.gen_unordered(generator) + generator = self.gen_unordered(generator, process) client = self.collection.database.client if not write_concern.acknowledged: diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index d73bfb2a2b..0b0d4190f9 100644 --- a/pymongo/synchronous/client_bulk.py +++ 
b/pymongo/synchronous/client_bulk.py @@ -116,6 +116,7 @@ def __init__( self.is_retryable = self.client.options.retry_writes self.retrying = False self.started_retryable_write = False + self.current_run = None @property def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]: @@ -486,7 +487,6 @@ def _execute_command( session: Optional[ClientSession], conn: Connection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], final_write_concern: Optional[WriteConcern] = None, ) -> None: @@ -532,10 +532,10 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, self.client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -562,7 +562,7 @@ def _execute_command( # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. - if retryable and (retryable_top_level_error or retryable_network_error): + if self.is_retryable and (retryable_top_level_error or retryable_network_error): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) @@ -581,7 +581,7 @@ def _execute_command( _merge_command(self.ops, self.idx_offset, full_result, result) break - if retryable: + if self.is_retryable: # Retryable writeConcernErrors halt the execution of this batch. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -636,7 +636,6 @@ def execute_command( def retryable_bulk( session: Optional[ClientSession], conn: Connection, - retryable: bool, ) -> None: if conn.max_wire_version < 25: raise InvalidOperation( @@ -647,12 +646,10 @@ def retryable_bulk( session, conn, op_id, - retryable, full_result, ) self.client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index aaf2d7574f..dc52a24911 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -851,11 +851,12 @@ def _finish_transaction_with_retry(self, command_name: str) -> dict[str, Any]: """ def func( - _session: Optional[ClientSession], conn: Connection, _retryable: bool + _session: Optional[ClientSession], + conn: Connection, ) -> dict[str, Any]: return self._finish_transaction(conn, command_name) - return self._client._retry_internal(func, self, None, retryable=True, operation=_Op.ABORT) + return self._client._retry_internal(func, self, None, operation=_Op.ABORT) def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 52c25de744..27b2a072d3 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -22,7 +22,6 @@ Any, Callable, ContextManager, - Generator, Generic, Iterable, Iterator, @@ -699,7 +698,7 @@ def _create( @_csot.apply def bulk_write( self, - requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]], + requests: Iterable[_WriteOp], ordered: bool = 
True, bypass_document_validation: Optional[bool] = None, session: Optional[ClientSession] = None, @@ -784,7 +783,16 @@ def bulk_write( blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let) write_concern = self._write_concern_for(session) - bulk_api_result = blk.execute(requests, write_concern, session, _Op.INSERT) + + def process_for_bulk(request: _WriteOp) -> bool: + try: + return request._add_to_bulk(blk) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + + bulk_api_result = blk.execute( + requests, process_for_bulk, write_concern, session, _Op.INSERT + ) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) @@ -801,17 +809,13 @@ def _insert_one( ) -> Any: """Internal helper for inserting a single document.""" write_concern = write_concern or self.write_concern - acknowledged = write_concern.acknowledged command = {"insert": self.name, "ordered": ordered, "documents": [doc]} if comment is not None: command["comment"] = comment - def _insert_command( - session: Optional[ClientSession], conn: Connection, retryable_write: bool - ) -> None: + def _insert_command(session: Optional[ClientSession], conn: Connection) -> None: if bypass_doc_val is not None: command["bypassDocumentValidation"] = bypass_doc_val - result = conn.command( self._database.name, command, @@ -819,14 +823,11 @@ def _insert_command( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) _check_write_command_response(result) - self._database.client._retryable_write( - acknowledged, _insert_command, session, operation=_Op.INSERT - ) + self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT) if not isinstance(doc, RawBSONDocument): return doc.get("_id") @@ -955,20 +956,19 @@ def insert_many( raise TypeError("documents must be a non-empty list") inserted_ids: list[ObjectId] = [] - def gen() -> Iterator[tuple[int, Mapping[str, Any]]]: + def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool: """A generator that validates documents and handles _ids.""" - for document in documents: - common.validate_is_document_type("document", document) - if not isinstance(document, RawBSONDocument): - if "_id" not in document: - document["_id"] = ObjectId() # type: ignore[index] - inserted_ids.append(document["_id"]) - yield (message._INSERT, document) + common.validate_is_document_type("document", document) + if not isinstance(document, RawBSONDocument): + if "_id" not in document: + document["_id"] = ObjectId() # type: ignore[index] + inserted_ids.append(document["_id"]) + blk.ops.append((message._INSERT, document)) + return True write_concern = self._write_concern_for(session) blk = _Bulk(self, ordered, bypass_document_validation, comment=comment) - blk.ops = list(gen()) - blk.execute(write_concern, session, _Op.INSERT) + blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT) return InsertManyResult(inserted_ids, write_concern.acknowledged) def _update( @@ -986,7 +986,6 @@ def _update( array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, - retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -1049,7 +1048,6 @@ def _update( codec_options=self._write_response_codec_options, 
session=session, client=self._database.client, - retryable_write=retryable_write, ) ).copy() _check_write_command_response(result) @@ -1089,7 +1087,7 @@ def _update_retryable( """Internal update / replace helper.""" def _update( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[ClientSession], conn: Connection ) -> Optional[Mapping[str, Any]]: return self._update( conn, @@ -1105,14 +1103,12 @@ def _update( array_filters=array_filters, hint=hint, session=session, - retryable_write=retryable_write, let=let, sort=sort, comment=comment, ) return self._database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, _update, session, operation, @@ -1502,7 +1498,6 @@ def _delete( collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, - retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Mapping[str, Any]: @@ -1542,7 +1537,6 @@ def _delete( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) _check_write_command_response(result) return result @@ -1562,9 +1556,7 @@ def _delete_retryable( ) -> Mapping[str, Any]: """Internal delete helper.""" - def _delete( - session: Optional[ClientSession], conn: Connection, retryable_write: bool - ) -> Mapping[str, Any]: + def _delete(session: Optional[ClientSession], conn: Connection) -> Mapping[str, Any]: return self._delete( conn, criteria, @@ -1575,13 +1567,11 @@ def _delete( collation=collation, hint=hint, session=session, - retryable_write=retryable_write, let=let, comment=comment, ) return self._database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, _delete, session, operation=_Op.DELETE, @@ -3219,9 +3209,7 @@ def _find_and_modify( write_concern = self._write_concern_for_cmd(cmd, session) - def _find_and_modify_helper( - session: Optional[ClientSession], conn: Connection, retryable_write: bool - ) -> Any: + def _find_and_modify_helper(session: Optional[ClientSession], conn: Connection) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: if not acknowledged: @@ -3246,7 +3234,6 @@ def _find_and_modify_helper( write_concern=write_concern, collation=collation, session=session, - retryable_write=retryable_write, user_fields=_FIND_AND_MODIFY_DOC_FIELDS, ) _check_write_command_response(out) @@ -3254,7 +3241,6 @@ def _find_and_modify_helper( return out.get("value") return self._database.client._retryable_write( - write_concern.acknowledged, _find_and_modify_helper, session, operation=_Op.FIND_AND_MODIFY, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 2d8d6d730b..3c657c214c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -141,7 +141,7 @@ T = TypeVar("T") -_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T] +_WriteCall = Callable[[Optional["ClientSession"], "Connection"], T] _ReadCall = Callable[ [Optional["ClientSession"], "Server", "Connection", _ServerMode], T, @@ -1888,7 +1888,6 @@ def _cmd( def _retry_with_session( self, - retryable: bool, func: _WriteCall[T], session: Optional[ClientSession], bulk: Optional[Union[_Bulk, _ClientBulk]], @@ -1904,15 +1903,11 @@ def _retry_with_session( """ # Ensure that the options supports retry_writes and there is a valid session not in # transaction, 
otherwise, we will not support retry behavior for this txn. - retryable = bool( - retryable and self.options.retry_writes and session and not session.in_transaction - ) return self._retry_internal( func=func, session=session, bulk=bulk, operation=operation, - retryable=retryable, operation_id=operation_id, ) @@ -1926,7 +1921,6 @@ def _retry_internal( is_read: bool = False, address: Optional[_Address] = None, read_pref: Optional[_ServerMode] = None, - retryable: bool = False, operation_id: Optional[int] = None, ) -> T: """Internal retryable helper for all client transactions. @@ -1951,7 +1945,6 @@ def _retry_internal( session=session, read_pref=read_pref, address=address, - retryable=retryable, operation_id=operation_id, ).run() @@ -1994,13 +1987,11 @@ def _retryable_read( is_read=True, address=address, read_pref=read_pref, - retryable=retryable, operation_id=operation_id, ) def _retryable_write( self, - retryable: bool, func: _WriteCall[T], session: Optional[ClientSession], operation: str, @@ -2021,7 +2012,7 @@ def _retryable_write( :param bulk: bulk abstraction to execute operations in bulk, defaults to None """ with self._tmp_session(session) as s: - return self._retry_with_session(retryable, func, s, bulk, operation, operation_id) + return self._retry_with_session(func, s, bulk, operation, operation_id) def _cleanup_cursor_no_lock( self, @@ -2648,7 +2639,6 @@ def __init__( session: Optional[ClientSession] = None, read_pref: Optional[_ServerMode] = None, address: Optional[_Address] = None, - retryable: bool = False, operation_id: Optional[int] = None, ): self._last_error: Optional[Exception] = None @@ -2660,7 +2650,7 @@ def __init__( self._bulk = bulk self._session = session self._is_read = is_read - self._retryable = retryable + self._retryable = True self._read_pref = read_pref self._server_selector: Callable[[Selection], Selection] = ( read_pref if is_read else writable_server_selector # type: ignore @@ -2671,6 +2661,11 @@ def __init__( self._operation = operation self._operation_id = operation_id + def _bulk_retryable(self) -> bool: + if self._bulk is not None and self._bulk.current_run is not None: + return self._bulk.current_run.is_retryable + return True + def run(self) -> T: """Runs the supplied func() and attempts a retry @@ -2681,10 +2676,15 @@ def run(self) -> T: # Increment the transaction id up front to ensure any retry attempt # will use the proper txnNumber, even if server or socket selection # fails before the command can be sent. 
- if self._is_session_state_retryable() and self._retryable and not self._is_read: + if ( + self._is_session_state_retryable() + and self._retryable + and self._bulk_retryable() + and not self._is_read + ): self._session._start_retryable_write() # type: ignore - if self._bulk: - self._bulk.started_retryable_write = True + if self._bulk and self._bulk.current_run: + self._bulk.current_run.started_retryable_write = True while True: self._check_last_error(check_csot=True) @@ -2717,7 +2717,7 @@ def run(self) -> T: # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if not self._retryable or not self._bulk_retryable(): raise if isinstance(exc, ClientBulkWriteException) and exc.error: retryable_write_error_exc = isinstance( @@ -2734,7 +2734,10 @@ def run(self) -> T: else: raise if self._bulk: - self._bulk.retrying = True + if self._bulk.current_run: + self._bulk.current_run.retrying = True + else: + self._bulk.retrying = True else: self._retrying = True if not exc.has_error_label("NoWritesPerformed"): @@ -2747,11 +2750,19 @@ def run(self) -> T: def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" - return not self._retryable or (self._is_retrying() and not self._multiple_retries) + return ( + not self._retryable + or not self._bulk_retryable() + or (self._is_retrying() and not self._multiple_retries) + ) def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return self._bulk.retrying if self._bulk else self._retrying + return ( + self._bulk.current_run.retrying + if self._bulk is not None and self._bulk.current_run is not None + else self._retrying + ) def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2811,9 +2822,11 @@ def _write(self) -> T: # not support sessions raise the last error. self._check_last_error() self._retryable = False - return self._func(self._session, conn, self._retryable) # type: ignore + if self._bulk and self._bulk.current_run: + self._bulk.current_run.is_retryable = False + return self._func(self._session, conn) # type: ignore except PyMongoError as exc: - if not self._retryable: + if not self._retryable or not self._bulk_retryable(): raise # Add the RetryableWriteError label, if applicable. _add_retryable_write_error(exc, max_wire_version, is_mongos) @@ -2830,7 +2843,7 @@ def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and (not self._retryable or not self._bulk_retryable()): self._check_last_error() return self._func(self._session, self._server, conn, read_pref) # type: ignore diff --git a/test/asynchronous/test_bulk.py index 3becea0777..4d2338eae2 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -353,11 +353,6 @@ async def test_bulk_write_no_results(self): self.assertRaises(InvalidOperation, lambda: result.upserted_ids) async def test_bulk_write_invalid_arguments(self): - # The requests argument must be a list. - generator = (InsertOne[dict]({}) for _ in range(10)) - with self.assertRaises(TypeError): - await self.coll.bulk_write(generator) # type: ignore[arg-type] - # Document is not wrapped in a bulk write operation.
with self.assertRaises(TypeError): await self.coll.bulk_write([{}]) # type: ignore[list-item] diff --git a/test/test_bulk.py b/test/test_bulk.py index 3e631e661f..9696f6da1d 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -353,11 +353,6 @@ def test_bulk_write_no_results(self): self.assertRaises(InvalidOperation, lambda: result.upserted_ids) def test_bulk_write_invalid_arguments(self): - # The requests argument must be a list. - generator = (InsertOne[dict]({}) for _ in range(10)) - with self.assertRaises(TypeError): - self.coll.bulk_write(generator) # type: ignore[arg-type] - # Document is not wrapped in a bulk write operation. with self.assertRaises(TypeError): self.coll.bulk_write([{}]) # type: ignore[list-item] From a0ef0549a9e09ce351df7fa6a34afa09ed252a16 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 21 Apr 2025 15:31:17 -0700 Subject: [PATCH 3/3] retrying vars back in bulk --- pymongo/asynchronous/bulk.py | 47 +++++++++++++++------------- pymongo/asynchronous/mongo_client.py | 23 +++++--------- pymongo/bulk_shared.py | 3 -- pymongo/synchronous/bulk.py | 47 +++++++++++++++------------- pymongo/synchronous/mongo_client.py | 23 +++++--------- 5 files changed, 66 insertions(+), 77 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 630be7c25e..a98c2b99c1 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -112,6 +112,9 @@ def __init__( self.uses_hint_update = False self.uses_hint_delete = False self.uses_sort = False + self.is_retryable = True + self.retrying = False + self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None @@ -127,23 +130,23 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - @property - def is_retryable(self) -> bool: - if self.current_run: - return self.current_run.is_retryable - return True - - @property - def retrying(self) -> bool: - if self.current_run: - return self.current_run.retrying - return False - - @property - def started_retryable_write(self) -> bool: - if self.current_run: - return self.current_run.started_retryable_write - return False + # @property + # def is_retryable(self) -> bool: + # if self.current_run: + # return self.current_run.is_retryable + # return True + # + # @property + # def retrying(self) -> bool: + # if self.current_run: + # return self.current_run.retrying + # return False + # + # @property + # def started_retryable_write(self) -> bool: + # if self.current_run: + # return self.current_run.started_retryable_write + # return False def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" @@ -255,7 +258,7 @@ def gen_ordered( yield run run = _Run(op_type) run.add(idx, operation) - run.is_retryable = run.is_retryable and retryable + self.is_retryable = self.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run @@ -273,7 +276,7 @@ def gen_unordered( retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - operations[op_type].is_retryable = operations[op_type].is_retryable and retryable + self.is_retryable = self.is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -533,7 +536,7 @@ async def _execute_command( last_run = False while run: - if not run.retrying: + if not self.retrying: self.next_run = next(generator, 
None) if self.next_run is None: last_run = True @@ -567,10 +570,10 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if run.is_retryable and not run.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 4c8230cea9..1675ce801d 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2674,8 +2674,8 @@ def __init__( self._operation_id = operation_id def _bulk_retryable(self) -> bool: - if self._bulk is not None and self._bulk.current_run is not None: - return self._bulk.current_run.is_retryable + if self._bulk is not None: + return self._bulk.is_retryable return True async def run(self) -> T: @@ -2695,8 +2695,8 @@ async def run(self) -> T: and not self._is_read ): self._session._start_retryable_write() # type: ignore - if self._bulk and self._bulk.current_run: - self._bulk.current_run.started_retryable_write = True + if self._bulk: + self._bulk.started_retryable_write = True while True: self._check_last_error(check_csot=True) @@ -2746,10 +2746,7 @@ async def run(self) -> T: else: raise if self._bulk: - if self._bulk.current_run: - self._bulk.current_run.retrying = True - else: - self._bulk.retrying = True + self._bulk.retrying = True else: self._retrying = True if not exc.has_error_label("NoWritesPerformed"): @@ -2770,11 +2767,7 @@ def _is_not_eligible_for_retry(self) -> bool: def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return ( - self._bulk.current_run.retrying - if self._bulk is not None and self._bulk.current_run is not None - else self._retrying - ) + return self._bulk.retrying if self._bulk is not None else self._retrying def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2834,8 +2827,8 @@ async def _write(self) -> T: # not support sessions raise the last error. self._check_last_error() self._retryable = False - if self._bulk and self._bulk.current_run: - self._bulk.current_run.is_retryable = False + if self._bulk: + self._bulk.is_retryable = False return await self._func(self._session, conn) # type: ignore except PyMongoError as exc: if not self._retryable or not self._bulk_retryable(): diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py index b157edd2e2..9276419d8a 100644 --- a/pymongo/bulk_shared.py +++ b/pymongo/bulk_shared.py @@ -50,9 +50,6 @@ def __init__(self, op_type: int) -> None: self.index_map: list[int] = [] self.ops: list[Any] = [] self.idx_offset: int = 0 - self.is_retryable = True - self.retrying = False - self.started_retryable_write = False def index(self, idx: int) -> int: """Get the original index of an operation in this run. 
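The net effect of the hunks above is that retryability now lives on the _Bulk itself instead of on each _Run: as operations stream in from the requests iterable, the bulk-level flag is the AND-fold of every operation's retryability, so a single non-retryable request (for example a multi-document update) disables retries for the whole bulk. A minimal, self-contained sketch of that folding pattern follows; the names Run, Bulk, and process are simplified stand-ins for PyMongo's internal _Run, _Bulk, and the process_for_bulk callback, not the real classes.

    from typing import Any, Callable, Iterable, Iterator

    _INSERT, _UPDATE, _DELETE = 0, 1, 2

    class Run:
        """One batch of same-typed operations; stands in for _Run."""

        def __init__(self, op_type: int) -> None:
            self.op_type = op_type
            self.ops: list[Any] = []

        def add(self, idx: int, operation: Any) -> None:
            self.ops.append((idx, operation))

    class Bulk:
        """Stands in for _Bulk: folds per-op retryability into one flag."""

        def __init__(self) -> None:
            self.ops: list[tuple[int, Any]] = []
            self.is_retryable = True  # cleared by the first non-retryable op

        def gen_ordered(
            self, requests: Iterable[Any], process: Callable[[Any], bool]
        ) -> Iterator[Run]:
            run = None
            for idx, request in enumerate(requests):
                retryable = process(request)  # appends one entry to self.ops
                op_type, operation = self.ops[idx]
                if run is None:
                    run = Run(op_type)
                elif run.op_type != op_type:
                    yield run  # flush the previous batch when the op type changes
                    run = Run(op_type)
                run.add(idx, operation)
                self.is_retryable = self.is_retryable and retryable
            if run is None:
                raise ValueError("No operations to execute")  # InvalidOperation in PyMongo
            yield run

    # Usage: the second op reports itself non-retryable, disabling retries.
    bulk = Bulk()

    def process(request: Any) -> bool:
        op_type, payload, retryable = request
        bulk.ops.append((op_type, payload))
        return retryable

    batches = list(
        bulk.gen_ordered([(_INSERT, {"x": 1}, True), (_UPDATE, {"x": 2}, False)], process)
    )
    assert [b.op_type for b in batches] == [_INSERT, _UPDATE]
    assert bulk.is_retryable is False

Because the requests iterable is consumed lazily, the flag can still flip to False after earlier batches have executed, which is why the client re-checks it before each retry decision rather than snapshotting it once up front.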
diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 2734d8d3fc..c3323ed841 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -112,6 +112,9 @@ def __init__( self.uses_hint_update = False self.uses_hint_delete = False self.uses_sort = False + self.is_retryable = True + self.retrying = False + self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None @@ -127,23 +130,23 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - @property - def is_retryable(self) -> bool: - if self.current_run: - return self.current_run.is_retryable - return True - - @property - def retrying(self) -> bool: - if self.current_run: - return self.current_run.retrying - return False - - @property - def started_retryable_write(self) -> bool: - if self.current_run: - return self.current_run.started_retryable_write - return False + # @property + # def is_retryable(self) -> bool: + # if self.current_run: + # return self.current_run.is_retryable + # return True + # + # @property + # def retrying(self) -> bool: + # if self.current_run: + # return self.current_run.retrying + # return False + # + # @property + # def started_retryable_write(self) -> bool: + # if self.current_run: + # return self.current_run.started_retryable_write + # return False def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" @@ -255,7 +258,7 @@ def gen_ordered( yield run run = _Run(op_type) run.add(idx, operation) - run.is_retryable = run.is_retryable and retryable + self.is_retryable = self.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run @@ -273,7 +276,7 @@ def gen_unordered( retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - operations[op_type].is_retryable = operations[op_type].is_retryable and retryable + self.is_retryable = self.is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -533,7 +536,7 @@ def _execute_command( last_run = False while run: - if not run.retrying: + if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: last_run = True @@ -567,10 +570,10 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if run.is_retryable and not run.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. 
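The synchronous mongo_client.py diff below mirrors the asynchronous changes: the retry helper no longer receives a retryable flag, it derives one from the bulk. A condensed, illustrative sketch of the resulting decision logic; BulkStub and Retryable are hypothetical stand-ins for the internal _Bulk and _ClientConnectionRetryable, and the always-True starting value of _retryable matches the "+ self._retryable = True" change earlier in the series.

    from typing import Optional

    class BulkStub:
        """Stands in for _Bulk: carries the folded retryability state."""

        def __init__(self) -> None:
            self.is_retryable = True
            self.retrying = False

    class Retryable:
        """Stands in for _ClientConnectionRetryable's bookkeeping."""

        def __init__(self, bulk: Optional[BulkStub], multiple_retries: bool = False) -> None:
            self._bulk = bulk
            self._retryable = True  # starts True; cleared for servers without session support
            self._retrying = False
            self._multiple_retries = multiple_retries

        def _bulk_retryable(self) -> bool:
            # Bulk-level state lives on the _Bulk, not on the current _Run,
            # so it survives across batches pulled lazily from a generator.
            return self._bulk.is_retryable if self._bulk is not None else True

        def _is_retrying(self) -> bool:
            return self._bulk.retrying if self._bulk is not None else self._retrying

        def _is_not_eligible_for_retry(self) -> bool:
            # Retry only if the client-level and bulk-level flags both agree
            # and we have not already used the single permitted retry.
            return (
                not self._retryable
                or not self._bulk_retryable()
                or (self._is_retrying() and not self._multiple_retries)
            )

    # One non-retryable op in the bulk makes the whole exchange ineligible.
    r = Retryable(bulk=BulkStub())
    r._bulk.is_retryable = False
    assert r._is_not_eligible_for_retry()

Keeping this state on the _Bulk rather than on the current _Run is exactly what this third patch restores: a retry can span more than one run, so per-run flags were lost at the moment they were needed.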
diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 3c657c214c..695e9be8b1 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2662,8 +2662,8 @@ def __init__( self._operation_id = operation_id def _bulk_retryable(self) -> bool: - if self._bulk is not None and self._bulk.current_run is not None: - return self._bulk.current_run.is_retryable + if self._bulk is not None: + return self._bulk.is_retryable return True def run(self) -> T: @@ -2683,8 +2683,8 @@ def run(self) -> T: and not self._is_read ): self._session._start_retryable_write() # type: ignore - if self._bulk and self._bulk.current_run: - self._bulk.current_run.started_retryable_write = True + if self._bulk: + self._bulk.started_retryable_write = True while True: self._check_last_error(check_csot=True) @@ -2734,10 +2734,7 @@ def run(self) -> T: else: raise if self._bulk: - if self._bulk.current_run: - self._bulk.current_run.retrying = True - else: - self._bulk.retrying = True + self._bulk.retrying = True else: self._retrying = True if not exc.has_error_label("NoWritesPerformed"): @@ -2758,11 +2755,7 @@ def _is_not_eligible_for_retry(self) -> bool: def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return ( - self._bulk.current_run.retrying - if self._bulk is not None and self._bulk.current_run is not None - else self._retrying - ) + return self._bulk.retrying if self._bulk is not None else self._retrying def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2822,8 +2815,8 @@ def _write(self) -> T: # not support sessions raise the last error. self._check_last_error() self._retryable = False - if self._bulk and self._bulk.current_run: - self._bulk.current_run.is_retryable = False + if self._bulk: + self._bulk.is_retryable = False return self._func(self._session, conn) # type: ignore except PyMongoError as exc: if not self._retryable or not self._bulk_retryable():
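Taken together, the series lets both the synchronous and asynchronous bulk_write consume write operations lazily instead of requiring a materialized sequence. A hypothetical usage sketch of the enabled behavior; it assumes pymongo with this series applied and a mongod reachable on the default localhost port.

    from pymongo import DeleteOne, InsertOne, MongoClient

    client = MongoClient()  # assumed local test deployment
    coll = client.test.docs

    def requests():
        # Operations are yielded lazily; nothing is buffered into a list
        # up front, so batches are built as the generator is consumed.
        for i in range(1000):
            yield InsertOne({"_id": i})
        yield DeleteOne({"_id": 0})

    result = coll.bulk_write(requests(), ordered=True)
    print(result.inserted_count, result.deleted_count)  # 1000 1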