diff --git a/docs/source/schemas.rst b/docs/source/schemas.rst
index 01fdd34..8abc4a9 100644
--- a/docs/source/schemas.rst
+++ b/docs/source/schemas.rst
@@ -141,7 +141,7 @@ Deleting a schema
 
 Any schema can be dropped, including ones not created by :class:`~psqlextra.schema.PostgresSchema`.
 
-The ``public`` schema cannot be dropped. This is a Postgres built-in and it is almost always a mistake to drop it. A :class:`~django.core.exceptions.SuspiciousOperation` erorr will be raised if you attempt to drop the ``public`` schema.
+The ``public`` schema cannot be dropped. This is a Postgres built-in and it is almost always a mistake to drop it. A :class:`~django.core.exceptions.SuspiciousOperation` error will be raised if you attempt to drop the ``public`` schema.
 
 .. warning::
 
diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py
index 88a65e9..bc529d7 100644
--- a/psqlextra/compiler.py
+++ b/psqlextra/compiler.py
@@ -3,22 +3,22 @@
 import sys
 
 from collections.abc import Iterable
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import django
 
 from django.conf import settings
 from django.core.exceptions import SuspiciousOperation
-from django.db.models import Expression, Model, Q
+from django.db.models import Expression, Field, Model, Q
 from django.db.models.fields.related import RelatedField
 from django.db.models.sql import compiler as django_compiler
 from django.db.utils import ProgrammingError
 
 from .expressions import HStoreValue
-from .types import ConflictAction
+from .types import ConflictAction, UpsertOperation
 
 
-def append_caller_to_sql(sql):
+def append_caller_to_sql(sql) -> str:
     """Append the caller to SQL queries.
 
     Adds the calling file and function as an SQL comment to each query.
@@ -162,26 +162,39 @@ def as_sql(self, *args, **kwargs):
 class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler):  # type: ignore [name-defined]
     """Compiler for SQL INSERT statements."""
 
+    RETURNING_OPERATION_TYPE_CLAUSE = (
+        f"CASE WHEN xmax::text::int > 0 "
+        f"THEN '{UpsertOperation.UPDATE.value}' "
+        f"ELSE '{UpsertOperation.INSERT.value}' END"
+    )
+    RETURNING_OPERATION_TYPE_FIELD = "_operation_type"
+
     def __init__(self, *args, **kwargs):
         """Initializes a new instance of
         :see:PostgresInsertOnConflictCompiler."""
         super().__init__(*args, **kwargs)
         self.qn = self.connection.ops.quote_name
 
-    def as_sql(self, return_id=False, *args, **kwargs):
+    def as_sql(
+        self,
+        return_id=False,
+        return_operation_type=False,
+        *args,
+        **kwargs,
+    ):
         """Builds the SQL INSERT statement."""
         queries = [
-            self._rewrite_insert(sql, params, return_id)
+            self._rewrite_insert(sql, params, return_id, return_operation_type)
             for sql, params in super().as_sql(*args, **kwargs)
         ]
 
         return queries
 
-    def execute_sql(self, return_id=False):
+    def execute_sql(self, return_id=False, return_operation_type=False):
         # execute all the generate queries
         with self.connection.cursor() as cursor:
             rows = []
-            for sql, params in self.as_sql(return_id):
+            for sql, params in self.as_sql(return_id, return_operation_type):
                 cursor.execute(sql, params)
                 try:
                     rows.extend(cursor.fetchall())
@@ -199,7 +212,9 @@ def execute_sql(self, return_id=False):
             for row in rows
         ]
 
-    def _rewrite_insert(self, sql, params, return_id=False):
+    def _rewrite_insert(
+        self, sql, params, return_id=False, return_operation_type=False
+    ):
         """Rewrites a formed SQL INSERT query to include the ON CONFLICT
         clause.
 
@@ -221,16 +236,27 @@ def _rewrite_insert(self, sql, params, return_id=False):
         returning = (
             self.qn(self.query.model._meta.pk.attname) if return_id else "*"
         )
+        # Return metadata about the row, so we can tell if it was inserted or
+        # updated by checking the `xmax` Postgres system column.
+        if return_operation_type:
+            returning += f", ({self.RETURNING_OPERATION_TYPE_CLAUSE}) AS {self.RETURNING_OPERATION_TYPE_FIELD}"
 
         (sql, params) = self._rewrite_insert_on_conflict(
-            sql, params, self.query.conflict_action.value, returning
+            sql,
+            params,
+            self.query.conflict_action.value,
+            returning,
         )
 
         return append_caller_to_sql(sql), params
 
     def _rewrite_insert_on_conflict(
-        self, sql, params, conflict_action: ConflictAction, returning
-    ):
+        self,
+        sql: str,
+        params: list,
+        conflict_action: ConflictAction,
+        returning: str,
+    ) -> Tuple[str, list]:
         """Rewrites a normal SQL INSERT query to add the 'ON CONFLICT'
         clause."""
 
@@ -256,7 +282,7 @@ def _rewrite_insert_on_conflict(
 
         rewritten_sql += f" DO {conflict_action}"
 
-        if conflict_action == "UPDATE":
+        if conflict_action == ConflictAction.UPDATE.value:
             rewritten_sql += f" SET {update_columns}"
 
             if update_condition:
@@ -353,7 +379,7 @@ def _build_conflict_target_by_index(self):
             stmt = matching_index.create_sql(self.query.model, schema_editor)
             return "(%s)" % stmt.parts["columns"]
 
-    def _get_model_field(self, name: str):
+    def _get_model_field(self, name: str) -> Optional[Field]:
         """Gets the field on a model with the specified name.
 
         Arguments:
@@ -384,7 +410,7 @@ def _get_model_field(self, name: str):
 
         return None
 
-    def _format_field_name(self, field_name) -> str:
+    def _format_field_name(self, field_name):
         """Formats a field's name for usage in SQL.
 
         Arguments:
@@ -399,7 +425,7 @@ def _format_field_name(self, field_name) -> str:
         field = self._get_model_field(field_name)
         return self.qn(field.column)
 
-    def _format_field_value(self, field_name) -> str:
+    def _format_field_value(self, field_name):
         """Formats a field's value for usage in SQL.
 
         Arguments:
@@ -432,7 +458,8 @@ def _format_field_value(self, field_name) -> str:
         )
 
     def _compile_expression(
-        self, expression: Union[Expression, Q, str]
+        self,
+        expression: Union[Expression, Q, str],
     ) -> Tuple[str, Union[tuple, list]]:
         """Compiles an expression, Q object or raw SQL string into SQL and
         tuple of parameters."""
@@ -452,7 +479,7 @@ def _compile_expression(
 
         return expression, tuple()
 
-    def _assert_valid_field(self, field_name: str):
+    def _assert_valid_field(self, field_name: str) -> None:
         """Asserts that a field with the specified name exists on the model
         and raises :see:SuspiciousOperation if it does not."""
 
diff --git a/psqlextra/query.py b/psqlextra/query.py
index b3feec1..aeff8c0 100644
--- a/psqlextra/query.py
+++ b/psqlextra/query.py
@@ -139,6 +139,7 @@ def bulk_insert(
         rows: Iterable[dict],
         return_model: bool = False,
         using: Optional[str] = None,
+        return_operation_type: bool = False,
     ):
         """Creates multiple new records in the database.
 
@@ -158,6 +159,13 @@ def bulk_insert(
                 Optional name of the database connection to use for
                 this query.
 
+            return_operation_type (default: False):
+                If the operation type should be returned for each row.
+                This is only supported when return_model is False.
+                The operation_type is either 'INSERT' or 'UPDATE' and
+                the value will be contained in the '_operation_type' key
+                of the returned dict.
+
         Returns:
             A list of either the dicts of the rows inserted, including the pk or
             the models of the rows inserted with defaults for any fields not specified
@@ -195,7 +203,10 @@ def is_empty(r):
             deduped_rows.append(row)
 
         compiler = self._build_insert_compiler(deduped_rows, using=using)
-        objs = compiler.execute_sql(return_id=not return_model)
+        objs = compiler.execute_sql(
+            return_id=not return_model,
+            return_operation_type=return_operation_type and not return_model,
+        )
         if return_model:
             return [
                 self._create_model_instance(dict(row, **obj), compiler.using)
@@ -261,7 +272,9 @@ def insert_and_get(self, using: Optional[str] = None, **fields):
             return super().create(**fields)
 
         compiler = self._build_insert_compiler([fields], using=using)
-        rows = compiler.execute_sql(return_id=False)
+        rows = compiler.execute_sql(
+            return_id=False, return_operation_type=False
+        )
         if not rows:
             return None
 
@@ -293,7 +306,7 @@ def upsert(
         index_predicate: Optional[Union[Expression, Q, str]] = None,
         using: Optional[str] = None,
         update_condition: Optional[Union[Expression, Q, str]] = None,
-    ) -> int:
+    ) -> Optional[int]:
         """Creates a new record or updates the existing one with the
         specified data.
 
@@ -336,7 +349,7 @@ def upsert_and_get(
         index_predicate: Optional[Union[Expression, Q, str]] = None,
         using: Optional[str] = None,
         update_condition: Optional[Union[Expression, Q, str]] = None,
-    ):
+    ) -> Optional[TModel]:
         """Creates a new record or updates the existing one with the
         specified data and then gets the row.
 
@@ -381,6 +394,7 @@ def bulk_upsert(
         return_model: bool = False,
         using: Optional[str] = None,
         update_condition: Optional[Union[Expression, Q, str]] = None,
+        return_operation_type: bool = False,
     ):
         """Creates a set of new records or updates the existing ones with the
         specified data.
@@ -407,6 +421,13 @@ def bulk_upsert(
             update_condition:
                 Only update if this SQL expression evaluates to true.
 
+            return_operation_type (default: False):
+                If the operation type should be returned for each row.
+                This is only supported when return_model is False.
+                The operation_type is either 'INSERT' or 'UPDATE' and
+                the value will be contained in the '_operation_type' key
+                of the returned dict.
+
         Returns:
             A list of either the dicts of the rows upserted, including the pk
             or the models of the rows upserted
@@ -418,15 +439,20 @@ def bulk_upsert(
             index_predicate=index_predicate,
             update_condition=update_condition,
         )
-        return self.bulk_insert(rows, return_model, using=using)
+        return self.bulk_insert(
+            rows,
+            return_model,
+            using=using,
+            return_operation_type=return_operation_type,
+        )
 
     def _create_model_instance(
         self, field_values: dict, using: str, apply_converters: bool = True
     ):
         """Creates a new instance of the model with the specified field.
 
-        Use this after the row was inserted into the database. The new
-        instance will marked as "saved".
+        Use this after the row was inserted into or updated in the database.
+        The new instance will be marked as "saved".
         """
 
         converted_field_values = field_values.copy()
diff --git a/psqlextra/types.py b/psqlextra/types.py
index a325fd9..1c007ba 100644
--- a/psqlextra/types.py
+++ b/psqlextra/types.py
@@ -29,6 +29,13 @@ def all(cls) -> List["ConflictAction"]:
         return [choice for choice in cls]
 
 
+class UpsertOperation(StrEnum):
+    """Possible operations to take on an upsert."""
+
+    INSERT = "INSERT"
+    UPDATE = "UPDATE"
+
+
 class PostgresPartitioningMethod(StrEnum):
     """Methods of partitioning supported by PostgreSQL 11.x native support
     for table partitioning."""
diff --git a/tests/test_on_conflict.py b/tests/test_on_conflict.py
index 02eda62..4b876ea 100644
--- a/tests/test_on_conflict.py
+++ b/tests/test_on_conflict.py
@@ -9,6 +9,7 @@
 from psqlextra.fields import HStoreField
 from psqlextra.models import PostgresModel
 from psqlextra.query import ConflictAction
+from psqlextra.types import UpsertOperation
 
 from .fake_model import get_fake_model
 
@@ -397,6 +398,37 @@ def test_bulk_return():
         assert obj["id"] == index
 
 
+def test_bulk_return_with_operation_type():
+    """Tests if the _operation_type is properly returned from 'bulk_insert'."""
+
+    model = get_fake_model(
+        {
+            "id": models.BigAutoField(primary_key=True),
+            "name": models.CharField(max_length=255, unique=True),
+        }
+    )
+
+    rows = [dict(name="John Smith"), dict(name="Jane Doe")]
+
+    objs = model.objects.on_conflict(
+        ["name"], ConflictAction.UPDATE
+    ).bulk_insert(rows, return_operation_type=True)
+
+    for index, obj in enumerate(objs, 1):
+        assert obj["id"] == index
+        assert obj["_operation_type"] == UpsertOperation.INSERT.value
+
+    # Add the objects again; the update should return the same ids
+    # as we're just updating.
+    objs = model.objects.on_conflict(
+        ["name"], ConflictAction.UPDATE
+    ).bulk_insert(rows, return_operation_type=True)
+
+    for index, obj in enumerate(objs, 1):
+        assert obj["id"] == index
+        assert obj["_operation_type"] == UpsertOperation.UPDATE.value
+
+
 @pytest.mark.parametrize("conflict_action", ConflictAction.all())
 def test_bulk_return_models(conflict_action):
     """Tests whether models are returned instead of dictionaries when
diff --git a/tests/test_upsert.py b/tests/test_upsert.py
index b9176da..2ba89e4 100644
--- a/tests/test_upsert.py
+++ b/tests/test_upsert.py
@@ -8,6 +8,7 @@
 from psqlextra.expressions import ExcludedCol
 from psqlextra.fields import HStoreField
 from psqlextra.query import ConflictAction
+from psqlextra.types import UpsertOperation
 
 from .fake_model import get_fake_model
 
@@ -259,6 +260,44 @@ def test_upsert_bulk_no_rows():
     )
 
 
+def test_upsert_bulk_returns_operation_type():
+    """Tests whether bulk_upsert works properly with the return_operation_type
+    flag."""
+
+    model = get_fake_model(
+        {
+            "first_name": models.CharField(
+                max_length=255, null=True, unique=True
+            ),
+            "last_name": models.CharField(max_length=255, null=True),
+        }
+    )
+
+    rows = model.objects.bulk_upsert(
+        conflict_target=["first_name"],
+        rows=[
+            dict(first_name="Swen", last_name="Kooij"),
+            dict(first_name="Henk", last_name="Test"),
+        ],
+        return_operation_type=True,
+    )
+
+    for row in rows:
+        assert row["_operation_type"] == UpsertOperation.INSERT.value
+
+    rows = model.objects.bulk_upsert(
+        conflict_target=["first_name"],
+        rows=[
+            dict(first_name="Swen", last_name="Test"),
+            dict(first_name="Henk", last_name="Kooij"),
+        ],
+        return_operation_type=True,
+    )
+
+    for row in rows:
+        assert row["_operation_type"] == UpsertOperation.UPDATE.value
+
+
 def test_bulk_upsert_return_models():
     """Tests whether models are returned instead of dictionaries when
     specifying the return_model=True argument."""
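The net effect of the patch, as a rough usage sketch mirroring the tests above (the `Person` model is hypothetical and assumes its table already exists; the `_operation_type` key is only present in dict results, i.e. when `return_model` is left at its default of `False`):

from django.db import models

from psqlextra.models import PostgresModel
from psqlextra.types import UpsertOperation


class Person(PostgresModel):
    # Hypothetical example model; any model with a usable conflict target
    # (here the unique "name" column) works the same way.
    name = models.CharField(max_length=255, unique=True)


rows = [dict(name="John Smith"), dict(name="Jane Doe")]

# First call: no conflicting rows exist yet, so every returned dict
# reports an INSERT in its '_operation_type' key.
inserted = Person.objects.bulk_upsert(
    conflict_target=["name"],
    rows=rows,
    return_operation_type=True,
)
assert all(
    row["_operation_type"] == UpsertOperation.INSERT.value for row in inserted
)

# Second call with the same rows: the conflict target matches, the rows are
# updated in place, and '_operation_type' now reports an UPDATE.
updated = Person.objects.bulk_upsert(
    conflict_target=["name"],
    rows=rows,
    return_operation_type=True,
)
assert all(
    row["_operation_type"] == UpsertOperation.UPDATE.value for row in updated
)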