PYTHON-1752 bulk_write should be able to accept a generator #2262

Draft: wants to merge 1 commit into base: master
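This draft implements PYTHON-1752: bulk_write should accept a generator of write operations, not just a list, so callers can stream large workloads without building them in memory first. A minimal usage sketch (assuming this change is applied and a mongod is listening on localhost:27017):

    from pymongo import InsertOne, MongoClient

    coll = MongoClient("mongodb://localhost:27017").test.docs
    # A generator expression: documents are produced one at a time instead of
    # being collected into a list before the call.
    result = coll.bulk_write(InsertOne({"i": i}) for i in range(100_000))
    print(result.inserted_count)  # 100000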
37 changes: 27 additions & 10 deletions pymongo/asynchronous/bulk.py
@@ -26,6 +26,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Generator,
     Iterator,
     Mapping,
     Optional,
@@ -72,7 +73,7 @@
 from pymongo.write_concern import WriteConcern
 
 if TYPE_CHECKING:
-    from pymongo.asynchronous.collection import AsyncCollection
+    from pymongo.asynchronous.collection import AsyncCollection, _WriteOp
     from pymongo.asynchronous.mongo_client import AsyncMongoClient
     from pymongo.asynchronous.pool import AsyncConnection
     from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
@@ -214,28 +215,45 @@ def add_delete(
         self.is_retryable = False
         self.ops.append((_DELETE, cmd))
 
-    def gen_ordered(self) -> Iterator[Optional[_Run]]:
+    def gen_ordered(self, requests) -> Iterator[Optional[_Run]]:
         """Generate batches of operations, batched by type of
         operation, in the order **provided**.
         """
         run = None
-        for idx, (op_type, operation) in enumerate(self.ops):
+        for idx, request in enumerate(requests):
A Member left a review comment on this line:
I think the goal of this ticket is to avoid inflating the whole generator upfront and only iterate requests as they are needed at the encoding step. For example:

    coll.bulk_write((InsertOne({'x': 'large' * 1024 * 1024}) for _ in range(1_000_000)))

If we inflate all at once like we do here, then that code will need to allocate all 1 million documents at once.

+            try:
+                request._add_to_bulk(self)
+            except AttributeError:
+                raise TypeError(f"{request!r} is not a valid request") from None
+            (op_type, operation) = self.ops[idx]
             if run is None:
                 run = _Run(op_type)
             elif run.op_type != op_type:
                 yield run
                 run = _Run(op_type)
             run.add(idx, operation)
+        if run is None:
+            raise InvalidOperation("No operations to execute")
         yield run

-    def gen_unordered(self) -> Iterator[_Run]:
+    def gen_unordered(self, requests) -> Iterator[_Run]:
         """Generate batches of operations, batched by type of
         operation, in arbitrary order.
         """
         operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
-        for idx, (op_type, operation) in enumerate(self.ops):
+        for idx, request in enumerate(requests):
+            try:
+                request._add_to_bulk(self)
+            except AttributeError:
+                raise TypeError(f"{request!r} is not a valid request") from None
+            (op_type, operation) = self.ops[idx]
             operations[op_type].add(idx, operation)
 
+        if (
+            len(operations[_INSERT].ops) == 0
+            and len(operations[_UPDATE].ops) == 0
+            and len(operations[_DELETE].ops) == 0
+        ):
+            raise InvalidOperation("No operations to execute")
         for run in operations:
             if run.ops:
                 yield run
@@ -726,23 +744,22 @@ async def execute_no_results(
 
     async def execute(
         self,
+        generator: Generator[_WriteOp[_DocumentType]],
         write_concern: WriteConcern,
         session: Optional[AsyncClientSession],
         operation: str,
     ) -> Any:
         """Execute operations."""
-        if not self.ops:
-            raise InvalidOperation("No operations to execute")
         if self.executed:
             raise InvalidOperation("Bulk operations can only be executed once.")
         self.executed = True
         write_concern = write_concern or self.collection.write_concern
         session = _validate_session_write_concern(session, write_concern)
 
         if self.ordered:
-            generator = self.gen_ordered()
+            generator = self.gen_ordered(generator)
         else:
-            generator = self.gen_unordered()
+            generator = self.gen_unordered(generator)
 
         client = self.collection.database.client
         if not write_concern.acknowledged:
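The review comment above suggests consuming the input one operation at a time. As a standalone illustration (not part of this PR), consecutive operations can be grouped into order-preserving runs lazily with itertools.groupby, so only the current batch of a million-element generator is ever materialized:

    import itertools
    from typing import Any, Iterable, Iterator

    _INSERT, _UPDATE, _DELETE = 0, 1, 2  # stand-ins for the op-type constants in bulk.py

    def lazy_gen_ordered(requests: Iterable[tuple[int, Any]]) -> Iterator[tuple[int, list[Any]]]:
        # Group consecutive (op_type, operation) pairs by op_type, preserving order;
        # groupby never looks ahead, so the input is pulled on demand.
        for op_type, group in itertools.groupby(requests, key=lambda pair: pair[0]):
            yield op_type, [operation for _, operation in group]

    ops = [(_INSERT, {"x": 1}), (_INSERT, {"x": 2}), (_UPDATE, {"q": {}}), (_INSERT, {"x": 3})]
    for op_type, batch in lazy_gen_ordered(iter(ops)):
        print(op_type, batch)  # three runs: two inserts, one update, one insert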
12 changes: 4 additions & 8 deletions pymongo/asynchronous/collection.py
@@ -23,6 +23,7 @@
     AsyncContextManager,
     Callable,
     Coroutine,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -699,7 +700,7 @@ async def _create(
     @_csot.apply
     async def bulk_write(
         self,
-        requests: Sequence[_WriteOp[_DocumentType]],
+        requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]],
         ordered: bool = True,
         bypass_document_validation: Optional[bool] = None,
         session: Optional[AsyncClientSession] = None,
@@ -779,17 +780,12 @@ async def bulk_write(
 
         .. versionadded:: 3.0
         """
-        common.validate_list("requests", requests)
+        common.validate_list_or_generator("requests", requests)
 
         blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let)
-        for request in requests:
-            try:
-                request._add_to_bulk(blk)
-            except AttributeError:
-                raise TypeError(f"{request!r} is not a valid request") from None
 
         write_concern = self._write_concern_for(session)
-        bulk_api_result = await blk.execute(write_concern, session, _Op.INSERT)
+        bulk_api_result = await blk.execute(requests, write_concern, session, _Op.INSERT)
         if bulk_api_result is not None:
             return BulkWriteResult(bulk_api_result, True)
         return BulkWriteResult({}, False)
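The same call shape works for the asynchronous API; a sketch under the same assumptions (change applied, local mongod):

    import asyncio

    from pymongo import AsyncMongoClient, InsertOne

    async def main() -> None:
        client = AsyncMongoClient("mongodb://localhost:27017")
        coll = client.test.docs
        # The generator is drained inside bulk_write as batches are encoded.
        result = await coll.bulk_write(InsertOne({"i": i}) for i in range(10_000))
        print(result.inserted_count)
        await client.close()

    asyncio.run(main())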
8 changes: 8 additions & 0 deletions pymongo/common.py
@@ -24,6 +24,7 @@
     TYPE_CHECKING,
     Any,
     Callable,
+    Generator,
     Iterator,
     Mapping,
     MutableMapping,
@@ -530,6 +531,13 @@ def validate_list(option: str, value: Any) -> list:
     return value
 
 
+def validate_list_or_generator(option: str, value: Any) -> Union[list, Generator]:
+    """Validates that 'value' is a list or generator."""
+    if isinstance(value, Generator):
+        return value
+    return validate_list(option, value)
+
+
 def validate_list_or_none(option: Any, value: Any) -> Optional[list]:
     """Validates that 'value' is a list or None."""
     if value is None:
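A quick sketch of the new helper's dispatch (assuming this change is applied): a generator passes through unchanged, while anything else falls back to the existing validate_list check:

    from pymongo import common

    gen = (x for x in range(3))
    assert common.validate_list_or_generator("requests", gen) is gen  # returned as-is
    assert common.validate_list_or_generator("requests", [1, 2]) == [1, 2]  # list path
    try:
        common.validate_list_or_generator("requests", "not a list")
    except TypeError as exc:
        print(exc)  # validate_list rejects non-lists, e.g. "requests must be a list"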
37 changes: 27 additions & 10 deletions pymongo/synchronous/bulk.py
@@ -26,6 +26,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Generator,
     Iterator,
     Mapping,
     Optional,
@@ -72,7 +73,7 @@
 from pymongo.write_concern import WriteConcern
 
 if TYPE_CHECKING:
-    from pymongo.synchronous.collection import Collection
+    from pymongo.synchronous.collection import Collection, _WriteOp
     from pymongo.synchronous.mongo_client import MongoClient
     from pymongo.synchronous.pool import Connection
     from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
@@ -214,28 +215,45 @@ def add_delete(
         self.is_retryable = False
         self.ops.append((_DELETE, cmd))
 
-    def gen_ordered(self) -> Iterator[Optional[_Run]]:
+    def gen_ordered(self, requests) -> Iterator[Optional[_Run]]:
         """Generate batches of operations, batched by type of
         operation, in the order **provided**.
         """
         run = None
-        for idx, (op_type, operation) in enumerate(self.ops):
+        for idx, request in enumerate(requests):
+            try:
+                request._add_to_bulk(self)
+            except AttributeError:
+                raise TypeError(f"{request!r} is not a valid request") from None
+            (op_type, operation) = self.ops[idx]
             if run is None:
                 run = _Run(op_type)
             elif run.op_type != op_type:
                 yield run
                 run = _Run(op_type)
             run.add(idx, operation)
+        if run is None:
+            raise InvalidOperation("No operations to execute")
         yield run

-    def gen_unordered(self) -> Iterator[_Run]:
+    def gen_unordered(self, requests) -> Iterator[_Run]:
         """Generate batches of operations, batched by type of
         operation, in arbitrary order.
         """
         operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
-        for idx, (op_type, operation) in enumerate(self.ops):
+        for idx, request in enumerate(requests):
+            try:
+                request._add_to_bulk(self)
+            except AttributeError:
+                raise TypeError(f"{request!r} is not a valid request") from None
+            (op_type, operation) = self.ops[idx]
             operations[op_type].add(idx, operation)
 
+        if (
+            len(operations[_INSERT].ops) == 0
+            and len(operations[_UPDATE].ops) == 0
+            and len(operations[_DELETE].ops) == 0
+        ):
+            raise InvalidOperation("No operations to execute")
         for run in operations:
             if run.ops:
                 yield run
@@ -724,23 +742,22 @@ def execute_no_results(
 
     def execute(
         self,
+        generator: Generator[_WriteOp[_DocumentType]],
         write_concern: WriteConcern,
         session: Optional[ClientSession],
         operation: str,
     ) -> Any:
         """Execute operations."""
-        if not self.ops:
-            raise InvalidOperation("No operations to execute")
         if self.executed:
             raise InvalidOperation("Bulk operations can only be executed once.")
         self.executed = True
         write_concern = write_concern or self.collection.write_concern
         session = _validate_session_write_concern(session, write_concern)
 
         if self.ordered:
-            generator = self.gen_ordered()
+            generator = self.gen_ordered(generator)
         else:
-            generator = self.gen_unordered()
+            generator = self.gen_unordered(generator)
 
         client = self.collection.database.client
         if not write_concern.acknowledged:
12 changes: 4 additions & 8 deletions pymongo/synchronous/collection.py
@@ -22,6 +22,7 @@
     Any,
     Callable,
     ContextManager,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -698,7 +699,7 @@ def _create(
     @_csot.apply
     def bulk_write(
         self,
-        requests: Sequence[_WriteOp[_DocumentType]],
+        requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]],
         ordered: bool = True,
         bypass_document_validation: Optional[bool] = None,
         session: Optional[ClientSession] = None,
@@ -778,17 +779,12 @@ def bulk_write(
 
         .. versionadded:: 3.0
         """
-        common.validate_list("requests", requests)
+        common.validate_list_or_generator("requests", requests)
 
         blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let)
-        for request in requests:
-            try:
-                request._add_to_bulk(blk)
-            except AttributeError:
-                raise TypeError(f"{request!r} is not a valid request") from None
 
         write_concern = self._write_concern_for(session)
-        bulk_api_result = blk.execute(write_concern, session, _Op.INSERT)
+        bulk_api_result = blk.execute(requests, write_concern, session, _Op.INSERT)
         if bulk_api_result is not None:
             return BulkWriteResult(bulk_api_result, True)
         return BulkWriteResult({}, False)
15 changes: 15 additions & 0 deletions test/asynchronous/test_bulk.py
@@ -299,6 +299,21 @@ async def test_numerous_inserts(self):
         self.assertEqual(n_docs, result.inserted_count)
         self.assertEqual(n_docs, await self.coll.count_documents({}))
 
+    async def test_numerous_inserts_generator(self):
+        # Ensure we don't exceed the server's maxWriteBatchSize limit.
+        n_docs = await async_client_context.max_write_batch_size + 100
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = await self.coll.bulk_write(requests, ordered=False)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, await self.coll.count_documents({}))
+
+        # Same with ordered bulk.
+        await self.coll.drop()
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = await self.coll.bulk_write(requests)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, await self.coll.count_documents({}))
+
     async def test_bulk_max_message_size(self):
         await self.coll.delete_many({})
         self.addAsyncCleanup(self.coll.delete_many, {})
15 changes: 15 additions & 0 deletions test/test_bulk.py
@@ -299,6 +299,21 @@ def test_numerous_inserts(self):
         self.assertEqual(n_docs, result.inserted_count)
         self.assertEqual(n_docs, self.coll.count_documents({}))
 
+    def test_numerous_inserts_generator(self):
+        # Ensure we don't exceed the server's maxWriteBatchSize limit.
+        n_docs = client_context.max_write_batch_size + 100
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = self.coll.bulk_write(requests, ordered=False)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, self.coll.count_documents({}))
+
+        # Same with ordered bulk.
+        self.coll.drop()
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = self.coll.bulk_write(requests)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, self.coll.count_documents({}))
+
     def test_bulk_max_message_size(self):
         self.coll.delete_many({})
         self.addCleanup(self.coll.delete_many, {})
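These tests cover the happy path; one consequence of the new InvalidOperation guard in gen_ordered/gen_unordered is worth noting. A sketch (assuming this change and a local mongod): a generator can be consumed only once, so passing it to bulk_write a second time leaves no operations to execute:

    from pymongo import InsertOne, MongoClient
    from pymongo.errors import InvalidOperation

    coll = MongoClient("mongodb://localhost:27017").test.docs
    requests = (InsertOne({"i": i}) for i in range(3))
    coll.bulk_write(requests)  # first call drains the generator
    try:
        coll.bulk_write(requests)  # generator is now exhausted
    except InvalidOperation as exc:
        print(exc)  # "No operations to execute"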