diff --git a/docs/src/piccolo/query_types/objects.rst b/docs/src/piccolo/query_types/objects.rst index 0d6f93d26..691ed55c2 100644 --- a/docs/src/piccolo/query_types/objects.rst +++ b/docs/src/piccolo/query_types/objects.rst @@ -361,6 +361,54 @@ It works with ``prefetch`` too: ------------------------------------------------------------------------------- +Comparing objects +----------------- + +If you have two objects, and you want to know whether they refer to the same +row in the database, you can simply use the equality operator: + +.. code-block:: python + + band_1 = await Band.objects().where(Band.name == "Pythonistas").first() + band_2 = await Band.objects().where(Band.name == "Pythonistas").first() + + >>> band_1 == band_2 + True + +It works by comparing the primary key value of each object. It's equivalent to +this: + +.. code-block:: python + + >>> band_1.id == band_2.id + True + +If the object has no primary key value yet (e.g. it uses a ``Serial`` column, +and it hasn't been saved in the database), then the result will always be +``False``: + +.. code-block:: python + + band_1 = Band() + band_2 = Band() + + >>> band_1 == band_2 + False + +If you want to compare every value on the objects, and not just the primary +key, you can use ``to_dict``. For example: + +.. code-block:: python + + >>> band_1.to_dict() == band_2.to_dict() + True + + >>> band_1.popularity = 10_000 + >>> band_1.to_dict() == band_2.to_dict() + False + +------------------------------------------------------------------------------- + Query clauses ------------- diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index e80fb254a..3f0a0e3f9 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -773,7 +773,7 @@ def column_type(self): return "INTEGER" raise Exception("Unrecognized engine type") - def default(self): + def default(self) -> QueryString: engine_type = self._meta.engine_type if engine_type == "postgres": diff --git a/piccolo/table.py b/piccolo/table.py index bae9b8a47..e5345d5c2 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -851,6 +851,72 @@ def __repr__(self) -> str: ) return f"<{self.__class__.__name__}: {pk}>" + def __eq__(self, other: t.Any) -> bool: + """ + Lets us check if two ``Table`` instances represent the same row in the + database, based on their primary key value:: + + band_1 = await Band.objects().where( + Band.name == "Pythonistas" + ).first() + + band_2 = await Band.objects().where( + Band.name == "Pythonistas" + ).first() + + band_3 = await Band.objects().where( + Band.name == "Rustaceans" + ).first() + + >>> band_1 == band_2 + True + + >>> band_1 == band_3 + False + + """ + if not isinstance(other, Table): + # This is the correct way to tell Python that this operation + # isn't supported: + # https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + # Make sure we're comparing the same table. + # There are several ways we could do this (like comparing tablename), + # but this should be OK. + if not isinstance(other, self.__class__): + return False + + pk = self._meta.primary_key + + pk_value = getattr( + self, + pk._meta.name, + ) + + other_pk_value = getattr( + other, + pk._meta.name, + ) + + # Make sure the primary key values are of the correct type. + # We need this for `Serial` columns, which have a `QueryString` + # value until saved in the database. We don't want to use `==` on + # two QueryString values, because QueryString has a custom `__eq__` + # method which doesn't return a boolean. + if isinstance( + pk_value, + pk.value_type, + ) and isinstance( + other_pk_value, + pk.value_type, + ): + return pk_value == other_pk_value + else: + # As a fallback, even if it hasn't been saved in the database, + # an object should still be equal to itself. + return other is self + ########################################################################### # Classmethods diff --git a/tests/table/instance/test_equality.py b/tests/table/instance/test_equality.py new file mode 100644 index 000000000..40ae59517 --- /dev/null +++ b/tests/table/instance/test_equality.py @@ -0,0 +1,74 @@ +from piccolo.columns.column_types import UUID, Varchar +from piccolo.table import Table +from piccolo.testing.test_case import AsyncTableTest +from tests.example_apps.music.tables import Manager + + +class ManagerUUID(Table): + id = UUID(primary_key=True) + name = Varchar() + + +class TestInstanceEquality(AsyncTableTest): + tables = [ + Manager, + ManagerUUID, + ] + + async def test_instance_equality(self) -> None: + """ + Make sure instance equality works, for tables with a `Serial` primary + key. + """ + manager_1 = Manager(name="Guido") + await manager_1.save() + + manager_2 = Manager(name="Graydon") + await manager_2.save() + + self.assertEqual(manager_1, manager_1) + self.assertNotEqual(manager_1, manager_2) + + # Try fetching the row from the database. + manager_1_from_db = ( + await Manager.objects().where(Manager.id == manager_1.id).first() + ) + self.assertEqual(manager_1, manager_1_from_db) + self.assertNotEqual(manager_2, manager_1_from_db) + + # Try rows which haven't been saved yet. + # They have no primary key value (because they use Serial columns + # as the primary key), so they shouldn't be equal. + self.assertNotEqual(Manager(), Manager()) + self.assertNotEqual(manager_1, Manager()) + + # Make sure an object is equal to itself, even if not saved. + manager_unsaved = Manager() + self.assertEqual(manager_unsaved, manager_unsaved) + + async def test_instance_equality_uuid(self) -> None: + """ + Make sure instance equality works, for tables with a `UUID` primary + key. + """ + manager_1 = ManagerUUID(name="Guido") + await manager_1.save() + + manager_2 = ManagerUUID(name="Graydon") + await manager_2.save() + + self.assertEqual(manager_1, manager_1) + self.assertNotEqual(manager_1, manager_2) + + # Try fetching the row from the database. + manager_1_from_db = ( + await ManagerUUID.objects() + .where(ManagerUUID.id == manager_1.id) + .first() + ) + self.assertEqual(manager_1, manager_1_from_db) + self.assertNotEqual(manager_2, manager_1_from_db) + + # Make sure an object is equal to itself, even if not saved. + manager_unsaved = ManagerUUID() + self.assertEqual(manager_unsaved, manager_unsaved)