Skip to content

INTPYTHON-355 Add transaction support #317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 2 additions & 23 deletions .evergreen/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,45 +68,24 @@ tasks:
- func: "run unit tests"

buildvariants:
- name: tests-6-noauth-nossl
display_name: Run Tests 6.0 NoAuth NoSSL
run_on: rhel87-small
expansions:
MONGODB_VERSION: "6.0"
TOPOLOGY: server
AUTH: "noauth"
SSL: "nossl"
tasks:
- name: run-tests

- name: tests-6-auth-ssl
display_name: Run Tests 6.0 Auth SSL
run_on: rhel87-small
expansions:
MONGODB_VERSION: "6.0"
TOPOLOGY: server
TOPOLOGY: sharded_cluster
AUTH: "auth"
SSL: "ssl"
tasks:
- name: run-tests

- name: tests-8-noauth-nossl
display_name: Run Tests 8.0 NoAuth NoSSL
run_on: rhel87-small
expansions:
MONGODB_VERSION: "8.0"
TOPOLOGY: server
AUTH: "noauth"
SSL: "nossl"
tasks:
- name: run-tests

- name: tests-8-auth-ssl
display_name: Run Tests 8.0 Auth SSL
run_on: rhel87-small
expansions:
MONGODB_VERSION: "8.0"
TOPOLOGY: server
TOPOLOGY: sharded_cluster
AUTH: "auth"
SSL: "ssl"
tasks:
Expand Down
2 changes: 1 addition & 1 deletion .evergreen/run-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ cp ./.github/workflows/mongodb_settings.py django_repo/tests/
cp ./.github/workflows/runtests.py django_repo/tests/runtests_.py

# Run tests
python django_repo/tests/runtests_.py
python django_repo/tests/runtests.py --settings mongodb_settings -v 2
2 changes: 1 addition & 1 deletion .github/workflows/test-python-atlas.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ jobs:
working-directory: .
run: bash .github/workflows/start_local_atlas.sh mongodb/mongodb-atlas-local:7
- name: Run tests
run: python3 django_repo/tests/runtests_.py
run: python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2
67 changes: 63 additions & 4 deletions django_mongodb_backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os

from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import debug_transaction
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from pymongo.collection import Collection
Expand Down Expand Up @@ -32,6 +34,17 @@ def __exit__(self, exception_type, exception_value, exception_traceback):
pass


def requires_transaction_support(func):
"""Make a method a no-op if transactions aren't supported."""

def wrapper(self, *args, **kwargs):
if not self.features.supports_transactions:
return
func(self, *args, **kwargs)

return wrapper


class DatabaseWrapper(BaseDatabaseWrapper):
data_types = {
"AutoField": "int",
Expand Down Expand Up @@ -140,6 +153,10 @@ def _isnull_operator(a, b):
ops_class = DatabaseOperations
validation_class = DatabaseValidation

def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
super().__init__(settings_dict, alias=alias)
self.session = None

def get_collection(self, name, **kwargs):
collection = Collection(self.database, name, **kwargs)
if self.queries_logged:
Expand Down Expand Up @@ -189,14 +206,48 @@ def _driver_info(self):
return DriverInfo("django-mongodb-backend", django_mongodb_backend_version)
return None

@requires_transaction_support
def _commit(self):
pass
if self.session:
with debug_transaction(self, "session.commit_transaction()"):
self.session.commit_transaction()
self._end_session()

@requires_transaction_support
def _rollback(self):
pass
if self.session:
with debug_transaction(self, "session.abort_transaction()"):
self.session.abort_transaction()
self._end_session()

def _start_transaction(self):
# Private API, specific to this backend.
if self.session is None:
self.session = self.connection.start_session()
with debug_transaction(self, "session.start_transaction()"):
self.session.start_transaction()

def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
self.autocommit = autocommit
def _end_session(self):
# Private API, specific to this backend.
self.session.end_session()
self.session = None

@requires_transaction_support
def _start_transaction_under_autocommit(self):
# Implementing this hook (intended only for SQLite), allows
# BaseDatabaseWrapper.set_autocommit() to use it to start a transaction
# rather than set_autocommit(), bypassing set_autocommit()'s call to
# debug_transaction(self, "BEGIN") which isn't semantic for a no-SQL
# backend.
self._start_transaction()

@requires_transaction_support
def _set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
# Besides @transaction.atomic() (which uses
# _start_transaction_under_autocommit(), disabling autocommit is
# another way to start a transaction.
if not autocommit:
self._start_transaction()

def _close(self):
# Normally called by close(), this method is also called by some tests.
Expand All @@ -210,6 +261,10 @@ def close(self):

def close_pool(self):
"""Close the MongoClient."""
# Clear commit hooks and session.
self.run_on_commit = []
if self.session:
self._end_session()
connection = self.connection
if connection is None:
return
Expand All @@ -225,6 +280,10 @@ def close_pool(self):
def cursor(self):
return Cursor()

@requires_transaction_support
def validate_no_broken_transaction(self):
super().validate_no_broken_transaction()

def get_database_version(self):
"""Return a tuple of the database's version."""
return tuple(self.connection.server_info()["versionArray"])
10 changes: 8 additions & 2 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,10 @@ def execute_sql(self, returning_fields=None):
@wrap_database_errors
def insert(self, docs, returning_fields=None):
"""Store a list of documents using field columns as element names."""
inserted_ids = self.collection.insert_many(docs).inserted_ids
self.connection.validate_no_broken_transaction()
inserted_ids = self.collection.insert_many(
docs, session=self.connection.session
).inserted_ids
return [(x,) for x in inserted_ids] if returning_fields else []

@cached_property
Expand Down Expand Up @@ -768,7 +771,10 @@ def execute_sql(self, result_type):

@wrap_database_errors
def update(self, criteria, pipeline):
return self.collection.update_many(criteria, pipeline).matched_count
self.connection.validate_no_broken_transaction()
return self.collection.update_many(
criteria, pipeline, session=self.connection.session
).matched_count

def check_query(self):
super().check_query()
Expand Down
77 changes: 62 additions & 15 deletions django_mongodb_backend/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_temporal_subtraction = True
# MongoDB stores datetimes in UTC.
supports_timezones = False
# Not implemented: https://github.com/mongodb/django-mongodb-backend/issues/7
supports_transactions = False
supports_unspecified_pk = True
uses_savepoints = False

Expand All @@ -50,8 +48,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"aggregation.tests.AggregateTestCase.test_order_by_aggregate_transform",
# 'NulledTransform' object has no attribute 'as_mql'.
"lookup.tests.LookupTests.test_exact_none_transform",
# "Save with update_fields did not affect any rows."
"basic.tests.SelectOnSaveTests.test_select_on_save_lying_update",
# BaseExpression.convert_value() crashes with Decimal128.
"aggregation.tests.AggregateTestCase.test_combine_different_types",
"annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation",
Expand Down Expand Up @@ -96,13 +92,47 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null",
"expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or",
}
_django_test_expected_failures_no_transactions = {
# "Save with update_fields did not affect any rows." instead of
# "An error occurred in the current transaction. You can't execute
# queries until the end of the 'atomic' block."
"basic.tests.SelectOnSaveTests.test_select_on_save_lying_update",
}
_django_test_expected_failures_transactions = {
# When update_or_create() fails with IntegrityError, the transaction
# is no longer usable.
"get_or_create.tests.UpdateOrCreateTests.test_manual_primary_key_test",
"get_or_create.tests.UpdateOrCreateTestsWithManualPKs.test_create_with_duplicate_primary_key",
# Tests that require savepoints
"admin_views.tests.AdminViewBasicTest.test_disallowed_to_field",
"admin_views.tests.AdminViewPermissionsTest.test_add_view",
"admin_views.tests.AdminViewPermissionsTest.test_change_view",
"admin_views.tests.AdminViewPermissionsTest.test_change_view_save_as_new",
"admin_views.tests.AdminViewPermissionsTest.test_delete_view",
"auth_tests.test_views.ChangelistTests.test_view_user_password_is_readonly",
"fixtures.tests.FixtureLoadingTests.test_loaddata_app_option",
"fixtures.tests.FixtureLoadingTests.test_unmatched_identifier_loading",
"fixtures_model_package.tests.FixtureTestCase.test_loaddata",
"get_or_create.tests.GetOrCreateTests.test_get_or_create_invalid_params",
"get_or_create.tests.UpdateOrCreateTests.test_integrity",
"many_to_many.tests.ManyToManyTests.test_add",
"many_to_one.tests.ManyToOneTests.test_fk_assignment_and_related_object_cache",
"model_fields.test_booleanfield.BooleanFieldTests.test_null_default",
"model_fields.test_floatfield.TestFloatField.test_float_validates_object",
"multiple_database.tests.QueryTestCase.test_generic_key_cross_database_protection",
"multiple_database.tests.QueryTestCase.test_m2m_cross_database_protection",
}

@cached_property
def django_test_expected_failures(self):
expected_failures = super().django_test_expected_failures
expected_failures.update(self._django_test_expected_failures)
if not self.is_mongodb_6_3:
expected_failures.update(self._django_test_expected_failures_bitwise)
if self.supports_transactions:
expected_failures.update(self._django_test_expected_failures_transactions)
else:
expected_failures.update(self._django_test_expected_failures_no_transactions)
return expected_failures

django_test_skips = {
Expand Down Expand Up @@ -485,16 +515,6 @@ def django_test_expected_failures(self):
"Connection health checks not implemented.": {
"backends.base.test_base.ConnectionHealthChecksTests",
},
"transaction.atomic() is not supported.": {
"backends.base.test_base.DatabaseWrapperLoggingTests",
"migrations.test_executor.ExecutorTests.test_atomic_operation_in_non_atomic_migration",
"migrations.test_operations.OperationTests.test_run_python_atomic",
},
"transaction.rollback() is not supported.": {
"transactions.tests.AtomicMiscTests.test_mark_for_rollback_on_error_in_autocommit",
"transactions.tests.AtomicMiscTests.test_mark_for_rollback_on_error_in_transaction",
"transactions.tests.NonAutocommitTests.test_orm_query_after_error_and_rollback",
},
"migrate --fake-initial is not supported.": {
"migrations.test_commands.MigrateTests.test_migrate_fake_initial",
"migrations.test_commands.MigrateTests.test_migrate_fake_split_initial",
Expand Down Expand Up @@ -533,8 +553,18 @@ def django_test_expected_failures(self):
"foreign_object.test_tuple_lookups.TupleLookupsTests",
},
"ColPairs is not supported.": {
# 'ColPairs' object has no attribute 'as_mql'
"auth_tests.test_views.CustomUserCompositePrimaryKeyPasswordResetTest",
"composite_pk.test_aggregate.CompositePKAggregateTests",
"composite_pk.test_create.CompositePKCreateTests",
"composite_pk.test_delete.CompositePKDeleteTests",
"composite_pk.test_filter.CompositePKFilterTests",
"composite_pk.test_get.CompositePKGetTests",
"composite_pk.test_models.CompositePKModelsTests",
"composite_pk.test_order_by.CompositePKOrderByTests",
"composite_pk.test_update.CompositePKUpdateTests",
"composite_pk.test_values.CompositePKValuesTests",
"composite_pk.tests.CompositePKTests",
"composite_pk.tests.CompositePKFixturesTests",
},
"Custom lookups are not supported.": {
"custom_lookups.tests.BilateralTransformTests",
Expand Down Expand Up @@ -577,3 +607,20 @@ def supports_atlas_search(self):
return False
else:
return True

@cached_property
def supports_select_union(self):
# Stage not supported inside of a multi-document transaction: $unionWith
return not self.supports_transactions

@cached_property
def supports_transactions(self):
"""
Transactions are enabled if MongoDB is configured as a replica set or a
sharded cluster.
"""
self.connection.ensure_connection()
client = self.connection.connection.admin
hello = client.command("hello")
# a replica set or a sharded cluster
return "setName" in hello or hello.get("msg") == "isdbgrid"
10 changes: 8 additions & 2 deletions django_mongodb_backend/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,23 @@ def __repr__(self):
@wrap_database_errors
def delete(self):
"""Execute a delete query."""
self.compiler.connection.validate_no_broken_transaction()
if self.compiler.subqueries:
raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.")
return self.compiler.collection.delete_many(self.match_mql).deleted_count
return self.compiler.collection.delete_many(
self.match_mql, session=self.compiler.connection.session
).deleted_count

@wrap_database_errors
def get_cursor(self):
"""
Return a pymongo CommandCursor that can be iterated on to give the
results of the query.
"""
return self.compiler.collection.aggregate(self.get_pipeline())
self.compiler.connection.validate_no_broken_transaction()
return self.compiler.collection.aggregate(
self.get_pipeline(), session=self.compiler.connection.session
)

def get_pipeline(self):
pipeline = []
Expand Down
2 changes: 1 addition & 1 deletion django_mongodb_backend/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, pipeline, using, model):
def _execute_query(self):
connection = connections[self.using]
collection = connection.get_collection(self.model._meta.db_table)
self.cursor = collection.aggregate(self.pipeline)
self.cursor = collection.aggregate(self.pipeline, session=connection.session)

def __str__(self):
return str(self.pipeline)
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"pymongo": ("https://pymongo.readthedocs.io/en/stable/", None),
"python": ("https://docs.python.org/3/", None),
"atlas": ("https://www.mongodb.com/docs/atlas/", None),
"manual": ("https://www.mongodb.com/docs/manual/", None),
}

root_doc = "contents"
Expand Down
Loading