Skip to content

PYTHON-1752 bulk_write should be able to accept a generator #2262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
Draft
110 changes: 82 additions & 28 deletions pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Mapping,
Optional,
Expand Down Expand Up @@ -72,7 +74,7 @@
from pymongo.write_concern import WriteConcern

if TYPE_CHECKING:
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.collection import AsyncCollection, _WriteOp
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
Expand Down Expand Up @@ -128,13 +130,14 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
self.is_encrypted = False
return _BulkWriteContext

def add_insert(self, document: _DocumentOut) -> None:
def add_insert(self, document: _DocumentOut) -> bool:
"""Add an insert document to the list of ops."""
validate_is_document_type("document", document)
# Generate ObjectId client side.
if not (isinstance(document, RawBSONDocument) or "_id" in document):
document["_id"] = ObjectId()
self.ops.append((_INSERT, document))
return True

def add_update(
self,
Expand All @@ -146,7 +149,7 @@ def add_update(
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
) -> bool:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi}
Expand All @@ -164,10 +167,12 @@ def add_update(
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort

self.ops.append((_UPDATE, cmd))
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append((_UPDATE, cmd))
return False
return True

def add_replace(
self,
Expand All @@ -177,7 +182,7 @@ def add_replace(
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
) -> bool:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
cmd: dict[str, Any] = {"q": selector, "u": replacement}
Expand All @@ -193,14 +198,15 @@ def add_replace(
self.uses_sort = True
cmd["sort"] = sort
self.ops.append((_UPDATE, cmd))
return True

def add_delete(
self,
selector: Mapping[str, Any],
limit: int,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
) -> bool:
"""Create a delete document and add it to the list of ops."""
cmd: dict[str, Any] = {"q": selector, "limit": limit}
if collation is not None:
Expand All @@ -209,33 +215,63 @@ def add_delete(
if hint is not None:
self.uses_hint_delete = True
cmd["hint"] = hint

self.ops.append((_DELETE, cmd))
if limit == _DELETE_ALL:
# A bulk_write containing a delete_many is not retryable.
self.is_retryable = False
self.ops.append((_DELETE, cmd))
return False
return True

def gen_ordered(self) -> Iterator[Optional[_Run]]:
def gen_ordered(
self,
requests: Iterable[Any],
process: Union[
Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool]
],
) -> Iterator[_Run]:
"""Generate batches of operations, batched by type of
operation, in the order **provided**.
"""
run = None
for idx, (op_type, operation) in enumerate(self.ops):
ctr = 0
for idx, request in enumerate(requests):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the goal of this ticket is to avoid inflating the whole generator upfront and only iterate requests as they are needed at the encoding step. For example:

coll.bulk_write((InsertOne({'x': 'large'*1024*1024}) for _ in range(1_000_000)))

If we inflate all at once like we do here, then that code will need to allocate all 1 million documents at once.

retryable = process(request)
(op_type, operation) = self.ops[idx]
if run is None:
run = _Run(op_type)
elif run.op_type != op_type:
elif run.op_type != op_type or ctr >= common.MAX_WRITE_BATCH_SIZE // 200:
yield run
ctr = 0
run = _Run(op_type)
ctr += 1
run.add(idx, operation)
run.is_retryable = run.is_retryable and retryable
if run is None:
raise InvalidOperation("No operations to execute")
yield run

def gen_unordered(self) -> Iterator[_Run]:
def gen_unordered(
self,
requests: Iterable[Any],
process: Union[
Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool]
],
) -> Iterator[_Run]:
"""Generate batches of operations, batched by type of
operation, in arbitrary order.
"""
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
for idx, (op_type, operation) in enumerate(self.ops):
for idx, request in enumerate(requests):
retryable = process(request)
(op_type, operation) = self.ops[idx]
operations[op_type].add(idx, operation)

operations[op_type].is_retryable = operations[op_type].is_retryable and retryable
if (
len(operations[_INSERT].ops) == 0
and len(operations[_UPDATE].ops) == 0
and len(operations[_DELETE].ops) == 0
):
raise InvalidOperation("No operations to execute")
for run in operations:
if run.ops:
yield run
Expand Down Expand Up @@ -472,6 +508,7 @@ async def _execute_command(
op_id: int,
retryable: bool,
full_result: MutableMapping[str, Any],
validate: bool,
final_write_concern: Optional[WriteConcern] = None,
) -> None:
db_name = self.collection.database.name
Expand All @@ -489,6 +526,7 @@ async def _execute_command(
last_run = False

while run:
self.is_retryable = run.is_retryable
if not self.retrying:
self.next_run = next(generator, None)
if self.next_run is None:
Expand Down Expand Up @@ -523,17 +561,21 @@ async def _execute_command(
if session:
# Start a new retryable write unless one was already
# started for this command.
if retryable and not self.started_retryable_write:
if retryable and self.is_retryable and not self.started_retryable_write:
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
session._apply_to(
cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn
)
conn.send_cluster_time(cmd, session, client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
conn.apply_timeout(client, cmd)
ops = islice(run.ops, run.idx_offset, None)

# Run as many ops as possible in one command.
if validate:
await self.validate_batch(conn, write_concern)
if write_concern.acknowledged:
result, to_send = await self._execute_batch(bwc, cmd, ops, client)

Expand Down Expand Up @@ -565,6 +607,9 @@ async def _execute_command(
break
# Reset our state
self.current_run = run = self.next_run
import gc

gc.collect()

async def execute_command(
self,
Expand Down Expand Up @@ -598,6 +643,7 @@ async def retryable_bulk(
op_id,
retryable,
full_result,
validate=False,
)

client = self.collection.database.client
Expand All @@ -615,7 +661,7 @@ async def retryable_bulk(
return full_result

async def execute_op_msg_no_results(
self, conn: AsyncConnection, generator: Iterator[Any]
self, conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern
) -> None:
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
db_name = self.collection.database.name
Expand Down Expand Up @@ -649,6 +695,7 @@ async def execute_op_msg_no_results(
conn.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
await self.validate_batch(conn, write_concern)
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
run.idx_offset += len(to_send)
self.current_run = run = next(generator, None)
Expand Down Expand Up @@ -684,10 +731,14 @@ async def execute_command_no_results(
op_id,
False,
full_result,
True,
write_concern,
)
except OperationFailure:
pass
except OperationFailure as exc:
if "Cannot set bypass_document_validation with unacknowledged write concern" in str(
exc
):
raise exc

async def execute_no_results(
self,
Expand All @@ -696,6 +747,11 @@ async def execute_no_results(
write_concern: WriteConcern,
) -> None:
"""Execute all operations, returning no results (w=0)."""
if self.ordered:
return await self.execute_command_no_results(conn, generator, write_concern)
return await self.execute_op_msg_no_results(conn, generator, write_concern)

async def validate_batch(self, conn: AsyncConnection, write_concern: WriteConcern) -> None:
if self.uses_collation:
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
if self.uses_array_filters:
Expand All @@ -720,29 +776,27 @@ async def execute_no_results(
"Cannot set bypass_document_validation with unacknowledged write concern"
)

if self.ordered:
return await self.execute_command_no_results(conn, generator, write_concern)
return await self.execute_op_msg_no_results(conn, generator)

async def execute(
self,
generator: Iterable[Any],
process: Union[
Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool]
],
write_concern: WriteConcern,
session: Optional[AsyncClientSession],
operation: str,
) -> Any:
"""Execute operations."""
if not self.ops:
raise InvalidOperation("No operations to execute")
if self.executed:
raise InvalidOperation("Bulk operations can only be executed once.")
self.executed = True
write_concern = write_concern or self.collection.write_concern
session = _validate_session_write_concern(session, write_concern)

if self.ordered:
generator = self.gen_ordered()
generator = self.gen_ordered(generator, process)
else:
generator = self.gen_unordered()
generator = self.gen_unordered(generator, process)

client = self.collection.database.client
if not write_concern.acknowledged:
Expand Down
1 change: 1 addition & 0 deletions pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
self.is_retryable = self.client.options.retry_writes
self.retrying = False
self.started_retryable_write = False
self.current_run = None

@property
def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
Expand Down
35 changes: 19 additions & 16 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ async def _create(
@_csot.apply
async def bulk_write(
self,
requests: Sequence[_WriteOp[_DocumentType]],
requests: Iterable[_WriteOp],
ordered: bool = True,
bypass_document_validation: Optional[bool] = None,
session: Optional[AsyncClientSession] = None,
Expand Down Expand Up @@ -779,17 +779,21 @@ async def bulk_write(

.. versionadded:: 3.0
"""
common.validate_list("requests", requests)
common.validate_list_or_generator("requests", requests)

blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let)
for request in requests:

write_concern = self._write_concern_for(session)

def process_for_bulk(request: _WriteOp) -> bool:
try:
request._add_to_bulk(blk)
return request._add_to_bulk(blk)
except AttributeError:
raise TypeError(f"{request!r} is not a valid request") from None

write_concern = self._write_concern_for(session)
bulk_api_result = await blk.execute(write_concern, session, _Op.INSERT)
bulk_api_result = await blk.execute(
requests, process_for_bulk, write_concern, session, _Op.INSERT
)
if bulk_api_result is not None:
return BulkWriteResult(bulk_api_result, True)
return BulkWriteResult({}, False)
Expand Down Expand Up @@ -960,20 +964,19 @@ async def insert_many(
raise TypeError("documents must be a non-empty list")
inserted_ids: list[ObjectId] = []

def gen() -> Iterator[tuple[int, Mapping[str, Any]]]:
def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool:
"""A generator that validates documents and handles _ids."""
for document in documents:
common.validate_is_document_type("document", document)
if not isinstance(document, RawBSONDocument):
if "_id" not in document:
document["_id"] = ObjectId() # type: ignore[index]
inserted_ids.append(document["_id"])
yield (message._INSERT, document)
common.validate_is_document_type("document", document)
if not isinstance(document, RawBSONDocument):
if "_id" not in document:
document["_id"] = ObjectId() # type: ignore[index]
inserted_ids.append(document["_id"])
blk.ops.append((message._INSERT, document))
return True

write_concern = self._write_concern_for(session)
blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment)
blk.ops = list(gen())
await blk.execute(write_concern, session, _Op.INSERT)
await blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT)
return InsertManyResult(inserted_ids, write_concern.acknowledged)

async def _update(
Expand Down
Loading
Loading