Skip to content
48 changes: 48 additions & 0 deletions docs/src/piccolo/query_types/objects.rst
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,54 @@ It works with ``prefetch`` too:

-------------------------------------------------------------------------------

Comparing objects
-----------------

If you have two objects, and you want to know whether they refer to the same
row in the database, you can simply use the equality operator:

.. code-block:: python

band_1 = await Band.objects().where(Band.name == "Pythonistas").first()
band_2 = await Band.objects().where(Band.name == "Pythonistas").first()

>>> band_1 == band_2
True

It works by comparing the primary key value of each object. It's equivalent to
this:

.. code-block:: python

>>> band_1.id == band_2.id
True

If the object has no primary key value yet (e.g. it uses a ``Serial`` column,
and it hasn't been saved in the database), then the result will always be
``False``:

.. code-block:: python

band_1 = Band()
band_2 = Band()

>>> band_1 == band_2
False

If you want to compare every value on the objects, and not just the primary
key, you can use ``to_dict``. For example:

.. code-block:: python

>>> band_1.to_dict() == band_2.to_dict()
True

>>> band_1.popularity = 10_000
>>> band_1.to_dict() == band_2.to_dict()
False

-------------------------------------------------------------------------------

Query clauses
-------------

Expand Down
2 changes: 1 addition & 1 deletion piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def column_type(self):
return "INTEGER"
raise Exception("Unrecognized engine type")

def default(self):
def default(self) -> QueryString:
engine_type = self._meta.engine_type

if engine_type == "postgres":
Expand Down
66 changes: 66 additions & 0 deletions piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,72 @@ def __repr__(self) -> str:
)
return f"<{self.__class__.__name__}: {pk}>"

def __eq__(self, other: t.Any) -> bool:
"""
Lets us check if two ``Table`` instances represent the same row in the
database, based on their primary key value::

band_1 = await Band.objects().where(
Band.name == "Pythonistas"
).first()

band_2 = await Band.objects().where(
Band.name == "Pythonistas"
).first()

band_3 = await Band.objects().where(
Band.name == "Rustaceans"
).first()

>>> band_1 == band_2
True

>>> band_1 == band_3
False

"""
if not isinstance(other, Table):
# This is the correct way to tell Python that this operation
# isn't supported:
# https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented

# Make sure we're comparing the same table.
# There are several ways we could do this (like comparing tablename),
# but this should be OK.
if not isinstance(other, self.__class__):
return False

pk = self._meta.primary_key

pk_value = getattr(
self,
pk._meta.name,
)

other_pk_value = getattr(
other,
pk._meta.name,
)

# Make sure the primary key values are of the correct type.
# We need this for `Serial` columns, which have a `QueryString`
# value until saved in the database. We don't want to use `==` on
# two QueryString values, because QueryString has a custom `__eq__`
# method which doesn't return a boolean.
if isinstance(
pk_value,
pk.value_type,
) and isinstance(
other_pk_value,
pk.value_type,
):
return pk_value == other_pk_value
else:
# As a fallback, even if it hasn't been saved in the database,
# an object should still be equal to itself.
return other is self

###########################################################################
# Classmethods

Expand Down
74 changes: 74 additions & 0 deletions tests/table/instance/test_equality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from piccolo.columns.column_types import UUID, Varchar
from piccolo.table import Table
from piccolo.testing.test_case import AsyncTableTest
from tests.example_apps.music.tables import Manager


class ManagerUUID(Table):
id = UUID(primary_key=True)
name = Varchar()


class TestInstanceEquality(AsyncTableTest):
tables = [
Manager,
ManagerUUID,
]

async def test_instance_equality(self) -> None:
"""
Make sure instance equality works, for tables with a `Serial` primary
key.
"""
manager_1 = Manager(name="Guido")
await manager_1.save()

manager_2 = Manager(name="Graydon")
await manager_2.save()

self.assertEqual(manager_1, manager_1)
self.assertNotEqual(manager_1, manager_2)

# Try fetching the row from the database.
manager_1_from_db = (
await Manager.objects().where(Manager.id == manager_1.id).first()
)
self.assertEqual(manager_1, manager_1_from_db)
self.assertNotEqual(manager_2, manager_1_from_db)

# Try rows which haven't been saved yet.
# They have no primary key value (because they use Serial columns
# as the primary key), so they shouldn't be equal.
self.assertNotEqual(Manager(), Manager())
self.assertNotEqual(manager_1, Manager())

# Make sure an object is equal to itself, even if not saved.
manager_unsaved = Manager()
self.assertEqual(manager_unsaved, manager_unsaved)

async def test_instance_equality_uuid(self) -> None:
"""
Make sure instance equality works, for tables with a `UUID` primary
key.
"""
manager_1 = ManagerUUID(name="Guido")
await manager_1.save()

manager_2 = ManagerUUID(name="Graydon")
await manager_2.save()

self.assertEqual(manager_1, manager_1)
self.assertNotEqual(manager_1, manager_2)

# Try fetching the row from the database.
manager_1_from_db = (
await ManagerUUID.objects()
.where(ManagerUUID.id == manager_1.id)
.first()
)
self.assertEqual(manager_1, manager_1_from_db)
self.assertNotEqual(manager_2, manager_1_from_db)

# Make sure an object is equal to itself, even if not saved.
manager_unsaved = ManagerUUID()
self.assertEqual(manager_unsaved, manager_unsaved)