diff --git a/docs/index.rst b/docs/index.rst index 125b24b..4f2387d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -272,7 +272,17 @@ Keyword Argument Description parameters. ``ignore_conflicts`` Specify True to ignore unique constraint or exclusion - constraint violation errors. The default is False. + constraint violation errors. The default is False. This + is deprecated in favor of `on_conflict={'action': 'ignore'}`. + +``on_conflict`` Specifies how PostgreSQL handles conflicts. For example, + `on_conflict={'action': 'ignore'}` will ignore any + conflicts. If setting `'action'` to `'update'`, you + must also specify `'target'` (the source of the + constraint: either a model field name, a constraint name, + or a list of model field names) as well as `'columns'` + (a list of model fields to update). The default is None, + which will raise conflict errors if they occur. ``using`` Sets the database to use when importing data. Default is None, which will use the ``'default'`` @@ -510,6 +520,41 @@ Now you can run that subclass directly rather than via a manager. The only diffe # Then save it. c.save() +For example, if you wish to return a QuerySet of the models imported into the database, you could do the following: + +.. 
code-block:: python + + + from django.db import models + from postgres_copy import CopyMapping, CopyQuerySet + + + class ResultsCopyMapping(CopyMapping): + def insert_suffix(self) -> str: + """Add `RETURNING` sql clause to get newly created/updated ids.""" + suffix = super().insert_suffix() + suffix = suffix.split(';')[0] + ' RETURNING id;' + return suffix + + def post_insert(self, cursor) -> None: + """Extend to store results from `RETURNING` clause.""" + self.obj_ids = [r[0] for r in cursor.fetchall()] + + + class ResultsCopyQuerySet(CopyQuerySet): + def from_csv(self, csv_path_or_obj, mapping=None, **kwargs): + mapping = ResultsCopyMapping(self.model, csv_path_or_obj, mapping=None, **kwargs) + count = mapping.save(silent=True) + objs = self.model.objects.filter(id__in=mapping.obj_ids) + return objs, count + + + class Person(models.Model): + name = models.CharField(max_length=500) + number = models.IntegerField() + source_csv = models.CharField(max_length=500) + objects = ResultsCopyQuerySet.as_manager() + Export options ============== diff --git a/postgres_copy/copy_from.py b/postgres_copy/copy_from.py index 6026ddf..7e251d6 100644 --- a/postgres_copy/copy_from.py +++ b/postgres_copy/copy_from.py @@ -9,6 +9,7 @@ import logging from collections import OrderedDict from io import TextIOWrapper +import warnings from django.db import NotSupportedError from django.db import connections, router from django.core.exceptions import FieldDoesNotExist @@ -33,6 +34,7 @@ def __init__( force_null=None, encoding=None, ignore_conflicts=False, + on_conflict={}, static_mapping=None, temp_table_name=None ): @@ -57,8 +59,9 @@ def __init__( self.force_not_null = force_not_null self.force_null = force_null self.encoding = encoding - self.supports_ignore_conflicts = True + self.supports_on_conflict = True self.ignore_conflicts = ignore_conflicts + self.on_conflict = on_conflict if static_mapping is not None: self.static_mapping = OrderedDict(static_mapping) else: @@ -76,10 +79,18 @@ 
def __init__( if self.conn.vendor != 'postgresql': raise TypeError("Only PostgreSQL backends supported") - # Check if it is PSQL 9.5 or greater, which determines if ignore_conflicts is supported - self.supports_ignore_conflicts = self.is_postgresql_9_5() - if self.ignore_conflicts and not self.supports_ignore_conflicts: - raise NotSupportedError('This database backend does not support ignoring conflicts.') + # Check if it is PSQL 9.5 or greater, which determines if on_conflict is supported + self.supports_on_conflict = self.is_postgresql_9_5() + if self.ignore_conflicts: + self.on_conflict = { + 'action': 'ignore', + } + warnings.warn( + "The `ignore_conflicts` kwarg has been replaced with " + "on_conflict={'action': 'ignore'}." + ) + if self.on_conflict and not self.supports_on_conflict: + raise NotSupportedError('This database backend does not support conflict logic.') # Pull the CSV headers self.headers = self.get_headers() @@ -317,10 +328,50 @@ def insert_suffix(self): """ Preps the suffix to the insert query. """ - if self.ignore_conflicts: + if self.on_conflict: + try: + action = self.on_conflict['action'] + except KeyError: + raise ValueError("Must specify an `action` when passing `on_conflict`.") + if action == 'ignore': + target, action = "", "DO NOTHING" + elif action == 'update': + try: + target = self.on_conflict['target'] + except KeyError: + raise ValueError("Must specify `target` when action == 'update'.") + try: + columns = self.on_conflict['columns'] + except KeyError: + raise ValueError("Must specify `columns` when action == 'update'.") + + # As recommended in PostgreSQL's INSERT documentation, we use "index inference" + # rather than naming a constraint directly. Currently, if an `include` param + # is provided to a django.models.Constraint, Django creates a UNIQUE INDEX instead + # of a CONSTRAINT, another reason to use "index inference" by just specifying columns. 
+ constraints = {c.name: c for c in self.model._meta.constraints} + if isinstance(target, str): + if constraint := constraints.get(target): + target = constraint.fields + else: + target = [target] + elif not isinstance(target, list): + raise ValueError("`target` must be a string or a list.") + target = "({0})".format(', '.join(target)) + + # Convert to db_column names and set values from the `excluded` table + columns = ', '.join([ + "{0} = excluded.{0}".format( + self.model._meta.get_field(col).column + ) + for col in columns + ]) + action = "DO UPDATE SET {0}".format(columns) + else: + raise ValueError("Action must be one of 'ignore' or 'update'.") return """ - ON CONFLICT DO NOTHING; - """ + ON CONFLICT {0} {1}; + """.format(target, action) else: return ";" diff --git a/postgres_copy/managers.py b/postgres_copy/managers.py index 6df7244..2509d6a 100644 --- a/postgres_copy/managers.py +++ b/postgres_copy/managers.py @@ -57,12 +57,18 @@ def drop_constraints(self): # Remove any field constraints for field in self.constrained_fields: - logger.debug("Dropping constraints from {}".format(field)) + logger.debug("Dropping field constraint from {}".format(field)) field_copy = field.__copy__() field_copy.db_constraint = False args = (self.model, field, field_copy) self.edit_schema(schema_editor, 'alter_field', args) + # Remove remaining constraints + for constraint in getattr(self.model._meta, 'constraints', []): + logger.debug("Dropping constraint '{}'".format(constraint.name)) + args = (self.model, constraint) + self.edit_schema(schema_editor, 'remove_constraint', args) + def drop_indexes(self): """ Drop indexes on the model and its fields. 
@@ -70,19 +76,25 @@ def drop_indexes(self): logger.debug("Dropping indexes from {}".format(self.model.__name__)) with connection.schema_editor() as schema_editor: # Remove any "index_together" constraints - logger.debug("Dropping index_together of {}".format(self.model._meta.index_together)) if self.model._meta.index_together: + logger.debug("Dropping index_together of {}".format(self.model._meta.index_together)) args = (self.model, self.model._meta.index_together, ()) self.edit_schema(schema_editor, 'alter_index_together', args) # Remove any field indexes for field in self.indexed_fields: - logger.debug("Dropping index from {}".format(field)) + logger.debug("Dropping field index from {}".format(field)) field_copy = field.__copy__() field_copy.db_index = False args = (self.model, field, field_copy) self.edit_schema(schema_editor, 'alter_field', args) + # Remove remaining indexes + for index in getattr(self.model._meta, 'indexes', []): + logger.debug("Dropping index '{}'".format(index.name)) + args = (self.model, index) + self.edit_schema(schema_editor, 'remove_index', args) + def restore_constraints(self): """ Restore constraints on the model and its fields. 
@@ -95,14 +107,20 @@ def restore_constraints(self): args = (self.model, (), self.model._meta.unique_together) self.edit_schema(schema_editor, 'alter_unique_together', args) - # Add any constraints to the fields + # Add any field constraints for field in self.constrained_fields: - logger.debug("Adding constraints to {}".format(field)) + logger.debug("Adding field constraint to {}".format(field)) field_copy = field.__copy__() field_copy.db_constraint = False args = (self.model, field_copy, field) self.edit_schema(schema_editor, 'alter_field', args) + # Add remaining constraints + for constraint in getattr(self.model._meta, 'constraints', []): + logger.debug("Adding constraint '{}'".format(constraint.name)) + args = (self.model, constraint) + self.edit_schema(schema_editor, 'add_constraint', args) + def restore_indexes(self): """ Restore indexes on the model and its fields. @@ -117,12 +135,18 @@ def restore_indexes(self): # Add any indexes to the fields for field in self.indexed_fields: - logger.debug("Restoring index to {}".format(field)) + logger.debug("Restoring field index to {}".format(field)) field_copy = field.__copy__() field_copy.db_index = False args = (self.model, field_copy, field) self.edit_schema(schema_editor, 'alter_field', args) + # Add remaining indexes + for index in getattr(self.model._meta, 'indexes', []): + logger.debug("Adding index '{}'".format(index.name)) + args = (self.model, index) + self.edit_schema(schema_editor, 'add_index', args) + class CopyQuerySet(ConstraintQuerySet): """ @@ -146,6 +170,15 @@ def from_csv(self, csv_path, mapping=None, drop_constraints=True, drop_indexes=T "anyway. 
Either remove the transaction block, or set " "drop_constraints=False and drop_indexes=False.") + # NOTE: See GH Issue #117 + # We could remove this block if drop_constraints' default was False + if on_conflict := kwargs.get('on_conflict'): + if target := on_conflict.get('target'): + if target in [c.name for c in self.model._meta.constraints]: + drop_constraints = False + elif on_conflict.get('action') == 'ignore': + drop_constraints = False + mapping = CopyMapping(self.model, csv_path, mapping, **kwargs) if drop_constraints: diff --git a/tests/models.py b/tests/models.py index a330f29..f94597a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -114,6 +114,35 @@ class SecondaryMockObject(models.Model): objects = CopyManager() -class UniqueMockObject(models.Model): +class UniqueFieldConstraintMockObject(models.Model): name = models.CharField(max_length=500, unique=True) objects = CopyManager() + + +class UniqueModelConstraintMockObject(models.Model): + name = models.CharField(max_length=500) + number = MyIntegerField(null=True, db_column='num') + objects = CopyManager() + + class Meta: + constraints = [ + models.UniqueConstraint( + name='constraint', + fields=['name'], + ), + ] + + +class UniqueModelConstraintAsIndexMockObject(models.Model): + name = models.CharField(max_length=500) + number = MyIntegerField(null=True, db_column='num') + objects = CopyManager() + + class Meta: + constraints = [ + models.UniqueConstraint( + name='constraint_as_index', + fields=['name'], + include=['number'], # Converts Constraint to Index + ), + ] diff --git a/tests/tests.py b/tests/tests.py index 4db9563..14667b4 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -12,7 +12,9 @@ OverloadMockObject, HookedCopyMapping, SecondaryMockObject, - UniqueMockObject + UniqueFieldConstraintMockObject, + UniqueModelConstraintMockObject, + UniqueModelConstraintAsIndexMockObject, ) from django.test import TestCase from django.db import transaction @@ -589,17 +591,93 @@ def 
test_encoding_save(self, _): @mock.patch("django.db.connection.validate_no_atomic_block") def test_ignore_conflicts(self, _): - UniqueMockObject.objects.from_csv( + UniqueFieldConstraintMockObject.objects.from_csv( self.name_path, dict(name='NAME'), ignore_conflicts=True ) - UniqueMockObject.objects.from_csv( + UniqueFieldConstraintMockObject.objects.from_csv( self.name_path, dict(name='NAME'), ignore_conflicts=True ) + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_on_conflict_ignore(self, _): + UniqueModelConstraintMockObject.objects.from_csv( + self.name_path, + dict(name='NAME', number='NUMBER'), + on_conflict={'action': 'ignore'}, + ) + UniqueModelConstraintMockObject.objects.from_csv( + self.name_path, + dict(name='NAME', number='NUMBER'), + on_conflict={'action': 'ignore'}, + ) + + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_on_conflict_target_field_update(self, _): + UniqueFieldConstraintMockObject.objects.from_csv( + self.name_path, + dict(name='NAME'), + on_conflict={ + 'action': 'update', + 'target': 'name', + 'columns': ['name'], + }, + ) + UniqueFieldConstraintMockObject.objects.from_csv( + self.name_path, + dict(name='NAME'), + on_conflict={ + 'action': 'update', + 'target': 'name', + 'columns': ['name'], + }, + ) + + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_on_conflict_target_constraint_update(self, _): + UniqueModelConstraintMockObject.objects.from_csv( + self.name_path, + dict(name='NAME', number='NUMBER'), + on_conflict={ + 'action': 'update', + 'target': 'constraint', + 'columns': ['name', 'number'], + }, + ) + UniqueModelConstraintMockObject.objects.from_csv( + self.name_path, + dict(name='NAME', number='NUMBER'), + on_conflict={ + 'action': 'update', + 'target': 'constraint', + 'columns': ['name', 'number'], + }, + ) + + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_on_conflict_target_constraint_as_index_update(self, _): + 
UniqueModelConstraintAsIndexMockObject.objects.from_csv( + self.name_path, + dict(name='NAME', number='NUMBER'), + on_conflict={ + 'action': 'update', + 'target': 'constraint_as_index', + 'columns': ['name', 'number'], + }, + ) + UniqueModelConstraintAsIndexMockObject.objects.from_csv( + self.name_path, + dict(name='NAME', number='NUMBER'), + on_conflict={ + 'action': 'update', + 'target': 'constraint_as_index', + 'columns': ['name', 'number'], + }, + ) + @mock.patch("django.db.connection.validate_no_atomic_block") def test_static_values(self, _): ExtendedMockObject.objects.from_csv(