diff --git a/.gitignore b/.gitignore index 66ab9352d..29a28f696 100644 --- a/.gitignore +++ b/.gitignore @@ -86,6 +86,7 @@ celerybeat-schedule # dotenv .env +.env3 # virtualenv .venv diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 309b604dd..598225eed 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,8 @@ Changelog - The ``auto_now_add`` argument of ``DatetimeField`` is handled correctly in the SQLite backend. - ``unique_together`` now creates named constrains, to prevent the DB from auto-assigning a potentially non-unique constraint name. - Filtering by an ``auto_now`` field doesn't replace the filter value with ``now()`` anymore. +- Implemented ``OneToOneField``, one to one relation between two models. +- Prefetching is done asynchronously now, sending all prefetch request at the same time instead of in sequence. 0.15.1 ------ diff --git a/CONTRIBUTORS.rst b/CONTRIBUTORS.rst index bf9d0fddf..7b79e27f1 100644 --- a/CONTRIBUTORS.rst +++ b/CONTRIBUTORS.rst @@ -21,6 +21,7 @@ Contributors * Adam Wallner ``@wallneradam`` * Zoltán Szeredi ``@zoliszeredi`` * Rebecca Klauser ``@svms1`` +* Sina Sohangir ``@sinaso`` Special Thanks ============== diff --git a/docs/fields.rst b/docs/fields.rst index 29068c969..38b74171d 100644 --- a/docs/fields.rst +++ b/docs/fields.rst @@ -77,6 +77,8 @@ Relational Fields .. autoclass:: tortoise.fields.ForeignKeyField :exclude-members: to_db_value, to_python_value +.. autoclass:: tortoise.fields.OneToOneField + .. autofunction:: tortoise.fields.ManyToManyField .. autodata:: tortoise.fields.ForeignKeyRelation diff --git a/examples/relations.py b/examples/relations.py index a1e0d6165..3dd34d9d1 100644 --- a/examples/relations.py +++ b/examples/relations.py @@ -34,6 +34,18 @@ def __str__(self): return self.name +class Address(Model): + city = fields.CharField(max_length=64) + street = fields.CharField(max_length=128) + + event: fields.OneToOneRelation[Event] = fields.OneToOneField( + "models.Event", on_delete=fields.CASCADE, related_name="address", pk=True + ) + + def __str__(self): + return f"Address({self.city}, {self.street})" + + class Team(Model): id = fields.IntField(pk=True) name = fields.TextField() @@ -53,6 +65,9 @@ async def run(): await Event(name="Without participants", tournament_id=tournament.id).save() event = Event(name="Test", tournament_id=tournament.id) await event.save() + + await Address.create(city="Santa Monica", street="Ocean", event=event) + participants = [] for i in range(2): team = Team(name=f"Team {(i + 1)}") @@ -96,6 +111,14 @@ async def run(): print(await Event.filter(id=event.id).values_list("id", "participants__name")) + print(await Address.filter(event=event).first()) + + event_reload1 = await Event.filter(id=event.id).first() + print(await event_reload1.address) + + event_reload2 = await Event.filter(id=event.id).prefetch_related("address").first() + print(event_reload2.address) + if __name__ == "__main__": run_async(run()) diff --git a/tests/model_bad_rel5.py b/tests/model_bad_rel5.py new file mode 100644 index 000000000..078199bd6 --- /dev/null +++ b/tests/model_bad_rel5.py @@ -0,0 +1,14 @@ +""" +Testing Models for a bad/wrong relation reference +Wrong reference. App missing. +""" +from tortoise import fields +from tortoise.models import Model + + +class Tournament(Model): + id = fields.IntField(pk=True) + + +class Event(Model): + tournament = fields.OneToOneField("Tournament") diff --git a/tests/models_dup3.py b/tests/models_dup3.py new file mode 100644 index 000000000..35e189609 --- /dev/null +++ b/tests/models_dup3.py @@ -0,0 +1,15 @@ +""" +This is the testing Models — Duplicate 3 +""" + +from tortoise import fields +from tortoise.models import Model + + +class Tournament(Model): + id = fields.IntField(pk=True) + event = fields.CharField(max_length=32) + + +class Event(Model): + tournament = fields.OneToOneField("models.Tournament", related_name="event") diff --git a/tests/models_o2o_2.py b/tests/models_o2o_2.py new file mode 100644 index 000000000..677406b5d --- /dev/null +++ b/tests/models_o2o_2.py @@ -0,0 +1,9 @@ +""" +This is the testing Models — Bad on_delete parameter +""" +from tortoise import fields +from tortoise.models import Model + + +class One(Model): + tournament = fields.OneToOneField("models.Two", on_delete="WABOOM") diff --git a/tests/models_o2o_3.py b/tests/models_o2o_3.py new file mode 100644 index 000000000..9247bf4b9 --- /dev/null +++ b/tests/models_o2o_3.py @@ -0,0 +1,9 @@ +""" +This is the testing Models — on_delete SET_NULL without null=True +""" +from tortoise import fields +from tortoise.models import Model + + +class One(Model): + tournament = fields.OneToOneField("models.Two", on_delete=fields.SET_NULL) diff --git a/tests/models_schema_create.py b/tests/models_schema_create.py index 7467a5733..035b3c73c 100644 --- a/tests/models_schema_create.py +++ b/tests/models_schema_create.py @@ -47,6 +47,22 @@ class Meta: indexes = [("manager", "key"), ["manager_id", "name"]] +class TeamAddress(Model): + city = fields.CharField(max_length=50, description="City") + country = fields.CharField(max_length=50, description="Country") + street = fields.CharField(max_length=128, description="Street Address") + team = fields.OneToOneField( + "models.Team", related_name="address", on_delete=fields.CASCADE, pk=True + ) + + +class VenueInformation(Model): + name = fields.CharField(max_length=128) + capacity = fields.IntField() + rent = fields.FloatField() + team = fields.OneToOneField("models.Team", on_delete=fields.SET_NULL, null=True) + + class SourceFields(Model): id = fields.IntField(pk=True, source_field="sometable_id") chars = fields.CharField(max_length=255, source_field="some_chars_table", index=True) diff --git a/tests/test_bad_relation_reference.py b/tests/test_bad_relation_reference.py index 0b1316529..f40009d4d 100644 --- a/tests/test_bad_relation_reference.py +++ b/tests/test_bad_relation_reference.py @@ -98,3 +98,24 @@ async def test_more_than_two_dots_in_reference_init(self): }, } ) + + async def test_no_app_in_o2o_reference_init(self): + with self.assertRaisesRegex( + ConfigurationError, 'OneToOneField accepts model name in format "app.Model"' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_bad_rel5"], + "default_connection": "default", + } + }, + } + ) diff --git a/tests/test_describe_model.py b/tests/test_describe_model.py index ec734097d..ae22fa760 100644 --- a/tests/test_describe_model.py +++ b/tests/test_describe_model.py @@ -101,6 +101,18 @@ async def test_describe_model_straight(self): "default": None, "description": "Tree!", }, + { + "db_column": "o2o_id", + "default": None, + "description": "Line", + "field_type": "IntField", + "generated": False, + "indexed": True, + "name": "o2o_id", + "nullable": True, + "python_type": "int", + "unique": True, + }, ], "fk_fields": [ { @@ -129,6 +141,33 @@ async def test_describe_model_straight(self): "description": "Tree!", } ], + "o2o_fields": [ + { + "default": None, + "description": "Line", + "field_type": "OneToOneField", + "generated": False, + "indexed": True, + "name": "o2o", + "nullable": True, + "python_type": "models.StraightFields", + "raw_field": "o2o_id", + "unique": True, + } + ], + "backward_o2o_fields": [ + { + "default": None, + "description": "Line", + "field_type": "BackwardOneToOneRelation", + "generated": False, + "indexed": False, + "name": "o2o_rev", + "nullable": True, + "python_type": "models.StraightFields", + "unique": False, + } + ], "m2m_fields": [ { "name": "rel_to", @@ -217,6 +256,18 @@ async def test_describe_model_straight_native(self): "default": None, "description": "Tree!", }, + { + "name": "o2o_id", + "field_type": fields.IntField, + "db_column": "o2o_id", + "python_type": int, + "generated": False, + "nullable": True, + "unique": True, + "indexed": True, + "default": None, + "description": "Line", + }, ], "fk_fields": [ { @@ -245,6 +296,33 @@ async def test_describe_model_straight_native(self): "description": "Tree!", } ], + "o2o_fields": [ + { + "default": None, + "description": "Line", + "field_type": fields.OneToOneField, + "generated": False, + "indexed": True, + "name": "o2o", + "nullable": True, + "python_type": StraightFields, + "raw_field": "o2o_id", + "unique": True, + }, + ], + "backward_o2o_fields": [ + { + "default": None, + "description": "Line", + "field_type": fields.BackwardOneToOneRelation, + "generated": False, + "indexed": False, + "name": "o2o_rev", + "nullable": True, + "python_type": StraightFields, + "unique": False, + }, + ], "m2m_fields": [ { "name": "rel_to", @@ -333,6 +411,18 @@ async def test_describe_model_source(self): "default": None, "description": "Tree!", }, + { + "name": "o2o_id", + "field_type": "IntField", + "db_column": "o2o_sometable", + "python_type": "int", + "generated": False, + "nullable": True, + "unique": True, + "indexed": True, + "default": None, + "description": "Line", + }, ], "fk_fields": [ { @@ -361,6 +451,33 @@ async def test_describe_model_source(self): "description": "Tree!", } ], + "o2o_fields": [ + { + "default": None, + "description": "Line", + "field_type": "OneToOneField", + "generated": False, + "indexed": True, + "name": "o2o", + "nullable": True, + "python_type": "models.SourceFields", + "raw_field": "o2o_id", + "unique": True, + } + ], + "backward_o2o_fields": [ + { + "default": None, + "description": "Line", + "field_type": "BackwardOneToOneRelation", + "generated": False, + "indexed": False, + "name": "o2o_rev", + "nullable": True, + "python_type": "models.SourceFields", + "unique": False, + } + ], "m2m_fields": [ { "name": "rel_to", @@ -449,6 +566,18 @@ async def test_describe_model_source_native(self): "default": None, "description": "Tree!", }, + { + "name": "o2o_id", + "field_type": fields.IntField, + "db_column": "o2o_sometable", + "python_type": int, + "generated": False, + "nullable": True, + "unique": True, + "indexed": True, + "default": None, + "description": "Line", + }, ], "fk_fields": [ { @@ -477,6 +606,33 @@ async def test_describe_model_source_native(self): "description": "Tree!", } ], + "o2o_fields": [ + { + "default": None, + "description": "Line", + "field_type": fields.OneToOneField, + "generated": False, + "indexed": True, + "name": "o2o", + "nullable": True, + "python_type": SourceFields, + "raw_field": "o2o_id", + "unique": True, + } + ], + "backward_o2o_fields": [ + { + "default": None, + "description": "Line", + "field_type": fields.BackwardOneToOneRelation, + "generated": False, + "indexed": False, + "name": "o2o_rev", + "nullable": True, + "python_type": SourceFields, + "unique": False, + } + ], "m2m_fields": [ { "name": "rel_to", @@ -554,6 +710,8 @@ async def test_describe_model_uuidpk(self): "description": None, }, ], + "o2o_fields": [], + "backward_o2o_fields": [], "m2m_fields": [ { "name": "peers", @@ -620,6 +778,8 @@ async def test_describe_model_uuidpk_native(self): "description": None, }, ], + "o2o_fields": [], + "backward_o2o_fields": [], "m2m_fields": [ { "name": "peers", @@ -700,6 +860,8 @@ async def test_describe_model_json(self): ], "fk_fields": [], "backward_fk_fields": [], + "o2o_fields": [], + "backward_o2o_fields": [], "m2m_fields": [], }, ) @@ -768,6 +930,8 @@ async def test_describe_model_json_native(self): ], "fk_fields": [], "backward_fk_fields": [], + "o2o_fields": [], + "backward_o2o_fields": [], "m2m_fields": [], }, ) diff --git a/tests/test_generate_schema.py b/tests/test_generate_schema.py index 90c13941e..ddc043e07 100644 --- a/tests/test_generate_schema.py +++ b/tests/test_generate_schema.py @@ -110,6 +110,18 @@ async def test_fk_bad_null(self): ): await self.init_for("tests.models_fk_3") + async def test_o2o_bad_on_delete(self): + with self.assertRaisesRegex( + ConfigurationError, "on_delete can only be CASCADE, RESTRICT or SET_NULL" + ): + await self.init_for("tests.models_o2o_2") + + async def test_o2o_bad_null(self): + with self.assertRaisesRegex( + ConfigurationError, "If on_delete is SET_NULL, then field must have null=True set" + ): + await self.init_for("tests.models_o2o_3") + async def test_m2m_bad_model_name(self): with self.assertRaisesRegex( ConfigurationError, 'Foreign key accepts model name in format "app.Model"' @@ -155,6 +167,12 @@ async def test_schema(self): ) /* The TEAMS! */; CREATE INDEX "idx_team_manager_676134" ON "team" ("manager_id", "key"); CREATE INDEX "idx_team_manager_ef8f69" ON "team" ("manager_id", "name"); +CREATE TABLE "teamaddress" ( + "city" VARCHAR(50) NOT NULL /* City */, + "country" VARCHAR(50) NOT NULL /* Country */, + "street" VARCHAR(128) NOT NULL /* Street Address */, + "team_id" VARCHAR(50) NOT NULL UNIQUE PRIMARY KEY REFERENCES "team" ("name") ON DELETE CASCADE +); CREATE TABLE "tournament" ( "tid" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "name" VARCHAR(100) NOT NULL /* Tournament name */, @@ -172,6 +190,13 @@ async def test_schema(self): CONSTRAINT "uid_event_name_c6f89f" UNIQUE ("name", "prize"), CONSTRAINT "uid_event_tournam_a5b730" UNIQUE ("tournament_id", "key") ) /* This table contains a list of all the events */; +CREATE TABLE "venueinformation" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "name" VARCHAR(128) NOT NULL, + "capacity" INT NOT NULL, + "rent" REAL NOT NULL, + "team_id" VARCHAR(50) UNIQUE REFERENCES "team" ("name") ON DELETE SET NULL +); CREATE TABLE "sometable_self" ( "backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE, "sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE @@ -219,6 +244,12 @@ async def test_schema_safe(self): ) /* The TEAMS! */; CREATE INDEX IF NOT EXISTS "idx_team_manager_676134" ON "team" ("manager_id", "key"); CREATE INDEX IF NOT EXISTS "idx_team_manager_ef8f69" ON "team" ("manager_id", "name"); +CREATE TABLE IF NOT EXISTS "teamaddress" ( + "city" VARCHAR(50) NOT NULL /* City */, + "country" VARCHAR(50) NOT NULL /* Country */, + "street" VARCHAR(128) NOT NULL /* Street Address */, + "team_id" VARCHAR(50) NOT NULL UNIQUE PRIMARY KEY REFERENCES "team" ("name") ON DELETE CASCADE +); CREATE TABLE IF NOT EXISTS "tournament" ( "tid" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "name" VARCHAR(100) NOT NULL /* Tournament name */, @@ -236,6 +267,13 @@ async def test_schema_safe(self): CONSTRAINT "uid_event_name_c6f89f" UNIQUE ("name", "prize"), CONSTRAINT "uid_event_tournam_a5b730" UNIQUE ("tournament_id", "key") ) /* This table contains a list of all the events */; +CREATE TABLE IF NOT EXISTS "venueinformation" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "name" VARCHAR(128) NOT NULL, + "capacity" INT NOT NULL, + "rent" REAL NOT NULL, + "team_id" VARCHAR(50) UNIQUE REFERENCES "team" ("name") ON DELETE SET NULL +); CREATE TABLE IF NOT EXISTS "sometable_self" ( "backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE, "sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE @@ -349,6 +387,13 @@ async def test_schema(self): KEY `idx_team_manager_676134` (`manager_id`, `key`), KEY `idx_team_manager_ef8f69` (`manager_id`, `name`) ) CHARACTER SET utf8mb4 COMMENT='The TEAMS!'; +CREATE TABLE `teamaddress` ( + `city` VARCHAR(50) NOT NULL COMMENT 'City', + `country` VARCHAR(50) NOT NULL COMMENT 'Country', + `street` VARCHAR(128) NOT NULL COMMENT 'Street Address', + `team_id` VARCHAR(50) NOT NULL UNIQUE PRIMARY KEY, + CONSTRAINT `fk_teamaddr_team_1c78d737` FOREIGN KEY (`team_id`) REFERENCES `team` (`name`) ON DELETE CASCADE +) CHARACTER SET utf8mb4; CREATE TABLE `tournament` ( `tid` SMALLINT NOT NULL PRIMARY KEY AUTO_INCREMENT, `name` VARCHAR(100) NOT NULL COMMENT 'Tournament name', @@ -367,6 +412,14 @@ async def test_schema(self): UNIQUE KEY `uid_event_tournam_a5b730` (`tournament_id`, `key`), CONSTRAINT `fk_event_tourname_51c2b82d` FOREIGN KEY (`tournament_id`) REFERENCES `tournament` (`tid`) ON DELETE CASCADE ) CHARACTER SET utf8mb4 COMMENT='This table contains a list of all the events'; +CREATE TABLE `venueinformation` ( + `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, + `name` VARCHAR(128) NOT NULL, + `capacity` INT NOT NULL, + `rent` DOUBLE NOT NULL, + `team_id` VARCHAR(50) UNIQUE, + CONSTRAINT `fk_venueinf_team_198af929` FOREIGN KEY (`team_id`) REFERENCES `team` (`name`) ON DELETE SET NULL +) CHARACTER SET utf8mb4; CREATE TABLE `sometable_self` ( `backward_sts` INT NOT NULL, `sts_forward` INT NOT NULL, @@ -423,6 +476,13 @@ async def test_schema_safe(self): KEY `idx_team_manager_676134` (`manager_id`, `key`), KEY `idx_team_manager_ef8f69` (`manager_id`, `name`) ) CHARACTER SET utf8mb4 COMMENT='The TEAMS!'; +CREATE TABLE IF NOT EXISTS `teamaddress` ( + `city` VARCHAR(50) NOT NULL COMMENT 'City', + `country` VARCHAR(50) NOT NULL COMMENT 'Country', + `street` VARCHAR(128) NOT NULL COMMENT 'Street Address', + `team_id` VARCHAR(50) NOT NULL UNIQUE PRIMARY KEY, + CONSTRAINT `fk_teamaddr_team_1c78d737` FOREIGN KEY (`team_id`) REFERENCES `team` (`name`) ON DELETE CASCADE +) CHARACTER SET utf8mb4; CREATE TABLE IF NOT EXISTS `tournament` ( `tid` SMALLINT NOT NULL PRIMARY KEY AUTO_INCREMENT, `name` VARCHAR(100) NOT NULL COMMENT 'Tournament name', @@ -441,6 +501,14 @@ async def test_schema_safe(self): UNIQUE KEY `uid_event_tournam_a5b730` (`tournament_id`, `key`), CONSTRAINT `fk_event_tourname_51c2b82d` FOREIGN KEY (`tournament_id`) REFERENCES `tournament` (`tid`) ON DELETE CASCADE ) CHARACTER SET utf8mb4 COMMENT='This table contains a list of all the events'; +CREATE TABLE IF NOT EXISTS `venueinformation` ( + `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, + `name` VARCHAR(128) NOT NULL, + `capacity` INT NOT NULL, + `rent` DOUBLE NOT NULL, + `team_id` VARCHAR(50) UNIQUE, + CONSTRAINT `fk_venueinf_team_198af929` FOREIGN KEY (`team_id`) REFERENCES `team` (`name`) ON DELETE SET NULL +) CHARACTER SET utf8mb4; CREATE TABLE IF NOT EXISTS `sometable_self` ( `backward_sts` INT NOT NULL, `sts_forward` INT NOT NULL, @@ -541,6 +609,15 @@ async def test_schema(self): CREATE INDEX "idx_team_manager_ef8f69" ON "team" ("manager_id", "name"); COMMENT ON COLUMN "team"."name" IS 'The TEAM name (and PK)'; COMMENT ON TABLE "team" IS 'The TEAMS!'; +CREATE TABLE "teamaddress" ( + "city" VARCHAR(50) NOT NULL, + "country" VARCHAR(50) NOT NULL, + "street" VARCHAR(128) NOT NULL, + "team_id" VARCHAR(50) NOT NULL UNIQUE PRIMARY KEY REFERENCES "team" ("name") ON DELETE CASCADE +); +COMMENT ON COLUMN "teamaddress"."city" IS 'City'; +COMMENT ON COLUMN "teamaddress"."country" IS 'Country'; +COMMENT ON COLUMN "teamaddress"."street" IS 'Street Address'; CREATE TABLE "tournament" ( "tid" SMALLSERIAL NOT NULL PRIMARY KEY, "name" VARCHAR(100) NOT NULL, @@ -565,6 +642,13 @@ async def test_schema(self): COMMENT ON COLUMN "event"."token" IS 'Unique token'; COMMENT ON COLUMN "event"."tournament_id" IS 'FK to tournament'; COMMENT ON TABLE "event" IS 'This table contains a list of all the events'; +CREATE TABLE "venueinformation" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "name" VARCHAR(128) NOT NULL, + "capacity" INT NOT NULL, + "rent" DOUBLE PRECISION NOT NULL, + "team_id" VARCHAR(50) UNIQUE REFERENCES "team" ("name") ON DELETE SET NULL +); CREATE TABLE "sometable_self" ( "backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE, "sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE @@ -615,6 +699,15 @@ async def test_schema_safe(self): CREATE INDEX IF NOT EXISTS "idx_team_manager_ef8f69" ON "team" ("manager_id", "name"); COMMENT ON COLUMN "team"."name" IS 'The TEAM name (and PK)'; COMMENT ON TABLE "team" IS 'The TEAMS!'; +CREATE TABLE IF NOT EXISTS "teamaddress" ( + "city" VARCHAR(50) NOT NULL, + "country" VARCHAR(50) NOT NULL, + "street" VARCHAR(128) NOT NULL, + "team_id" VARCHAR(50) NOT NULL UNIQUE PRIMARY KEY REFERENCES "team" ("name") ON DELETE CASCADE +); +COMMENT ON COLUMN "teamaddress"."city" IS 'City'; +COMMENT ON COLUMN "teamaddress"."country" IS 'Country'; +COMMENT ON COLUMN "teamaddress"."street" IS 'Street Address'; CREATE TABLE IF NOT EXISTS "tournament" ( "tid" SMALLSERIAL NOT NULL PRIMARY KEY, "name" VARCHAR(100) NOT NULL, @@ -639,6 +732,13 @@ async def test_schema_safe(self): COMMENT ON COLUMN "event"."token" IS 'Unique token'; COMMENT ON COLUMN "event"."tournament_id" IS 'FK to tournament'; COMMENT ON TABLE "event" IS 'This table contains a list of all the events'; +CREATE TABLE IF NOT EXISTS "venueinformation" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "name" VARCHAR(128) NOT NULL, + "capacity" INT NOT NULL, + "rent" DOUBLE PRECISION NOT NULL, + "team_id" VARCHAR(50) UNIQUE REFERENCES "team" ("name") ON DELETE SET NULL +); CREATE TABLE IF NOT EXISTS "sometable_self" ( "backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE, "sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE diff --git a/tests/test_init.py b/tests/test_init.py index 0256064cb..a518c22ae 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -72,6 +72,24 @@ async def test_dup2_init(self): } ) + async def test_dup3_init(self): + with self.assertRaisesRegex( + ConfigurationError, 'backward relation "event" duplicates in model Tournament' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.models_dup3"], "default_connection": "default"} + }, + } + ) + async def test_generated_nonint(self): with self.assertRaisesRegex( ConfigurationError, "Generated primary key allowed only for IntField and BigIntField" diff --git a/tests/test_model_methods.py b/tests/test_model_methods.py index 82c83f6f5..db65697a6 100644 --- a/tests/test_model_methods.py +++ b/tests/test_model_methods.py @@ -1,4 +1,4 @@ -from tests.testmodels import Event, NoID, Team, Tournament +from tests.testmodels import Address, Event, NoID, Team, Tournament from tortoise.contrib import test from tortoise.exceptions import ( ConfigurationError, @@ -153,6 +153,15 @@ def test_rev_m2m(self): ): Team(name="a", events=[]) + async def test_rev_o2o(self): + with self.assertRaisesRegex( + ConfigurationError, + "You can't set backward one to one relations through init, " + "change related model instead", + ): + address = await Address.create(city="Santa Monica", street="Ocean") + await Event(name="a", address=address) + def test_fk_unsaved(self): with self.assertRaisesRegex(OperationalError, "You should first call .save()"): Event(name="a", tournament=Tournament(name="a")) diff --git a/tests/test_prefetching.py b/tests/test_prefetching.py index 2c39cb8ab..e79ff3958 100644 --- a/tests/test_prefetching.py +++ b/tests/test_prefetching.py @@ -1,4 +1,4 @@ -from tests.testmodels import Event, Team, Tournament +from tests.testmodels import Address, Event, Team, Tournament from tortoise.contrib import test from tortoise.exceptions import FieldError, OperationalError from tortoise.functions import Count @@ -52,6 +52,15 @@ async def test_prefetch_m2m(self): ) self.assertEqual(len(fetched_events.participants), 1) + async def test_prefetch_o2o(self): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + await Address.create(city="Santa Monica", street="Ocean", event=event) + + fetched_events = await Event.all().prefetch_related("address").first() + + self.assertEqual(fetched_events.address.city, "Santa Monica") + async def test_prefetch_nested(self): tournament = await Tournament.create(name="tournament") event = await Event.create(name="First", tournament=tournament) diff --git a/tests/test_relations.py b/tests/test_relations.py index 67fd4eb82..21506c480 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,4 +1,4 @@ -from tests.testmodels import Employee, Event, Team, Tournament +from tests.testmodels import Address, Employee, Event, Team, Tournament from tortoise.contrib import test from tortoise.exceptions import FieldError, NoValuesFetched from tortoise.functions import Count @@ -135,6 +135,14 @@ async def test_m2m_remove(self): fetched_event = await Event.first().prefetch_related("participants") self.assertEqual(len(fetched_event.participants), 1) + async def test_o2o_lazy(self): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + await Address.create(city="Santa Monica", street="Ocean", event=event) + + fetched_address = await event.address + self.assertEqual(fetched_address.city, "Santa Monica") + async def test_m2m_remove_two(self): tournament = await Tournament.create(name="tournament") event = await Event.create(name="First", tournament=tournament) diff --git a/tests/testmodels.py b/tests/testmodels.py index a550b383d..10b04c3af 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -62,6 +62,15 @@ def __str__(self): return self.name +class Address(Model): + city = fields.CharField(max_length=64) + street = fields.CharField(max_length=128) + + event: fields.OneToOneRelation[Event] = fields.OneToOneField( + "models.Event", on_delete=fields.CASCADE, related_name="address", null=True + ) + + class Team(Model): id = fields.IntField(pk=True) name = fields.TextField() @@ -446,6 +455,11 @@ class StraightFields(Model): ) fkrev: fields.ReverseRelation["StraightFields"] + o2o: fields.OneToOneNullableRelation["StraightFields"] = fields.OneToOneField( + "models.StraightFields", related_name="o2o_rev", null=True, description="Line" + ) + o2o_rev: fields.Field + rel_to: fields.ManyToManyRelation["StraightFields"] = fields.ManyToManyField( "models.StraightFields", related_name="rel_from", description="M2M to myself" ) @@ -472,6 +486,15 @@ class SourceFields(Model): ) fkrev: fields.ReverseRelation["SourceFields"] + o2o: fields.OneToOneNullableRelation["SourceFields"] = fields.OneToOneField( + "models.SourceFields", + related_name="o2o_rev", + null=True, + source_field="o2o_sometable", + description="Line", + ) + o2o_rev: fields.Field + rel_to: fields.ManyToManyRelation["SourceFields"] = fields.ManyToManyField( "models.SourceFields", related_name="rel_from", diff --git a/tortoise/__init__.py b/tortoise/__init__.py index c2e07c604..a5b1c7231 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -74,6 +74,8 @@ def describe_model(cls, model: Type[Model], serializable: bool = True) -> dict: "data_fields": [...] # Data fields "fk_fields": [...] # Foreign Key fields FROM this model "backward_fk_fields": [...] # Foreign Key fields TO this model + "o2o_fields": [...] # OneToOne fields FROM this model + "backward_o2o_fields": [...] # OneToOne fields TO this model "m2m_fields": [...] # Many-to-Many fields } @@ -160,14 +162,21 @@ def describe_field(name: str) -> dict: } # Foreign Keys have - if isinstance(field, fields.ForeignKeyField): + if isinstance(field, (fields.ForeignKeyField, fields.OneToOneField)): del desc["db_column"] desc["raw_field"] = field.source_field else: del desc["raw_field"] # These fields are entierly "virtual", so no direct DB representation - if isinstance(field, (fields.ManyToManyFieldInstance, fields.BackwardFKRelation)): + if isinstance( + field, + ( + fields.ManyToManyFieldInstance, + fields.BackwardFKRelation, + fields.BackwardOneToOneRelation, + ), + ): del desc["db_column"] return desc @@ -196,6 +205,16 @@ def describe_field(name: str) -> dict: for name in model._meta.fields_map.keys() if name in model._meta.backward_fk_fields ], + "o2o_fields": [ + describe_field(name) + for name in model._meta.fields_map.keys() + if name in model._meta.o2o_fields + ], + "backward_o2o_fields": [ + describe_field(name) + for name in model._meta.fields_map.keys() + if name in model._meta.backward_o2o_fields + ], "m2m_fields": [ describe_field(name) for name in model._meta.fields_map.keys() @@ -283,6 +302,8 @@ def split_reference(reference: str) -> Tuple[str, str]: if not model._meta.table: model._meta.table = model.__name__.lower() + pk_attr_changed = False + for field in model._meta.fk_fields: fk_object = cast(fields.ForeignKeyField, model._meta.fields_map[field]) reference = fk_object.model_name @@ -320,6 +341,48 @@ def split_reference(reference: str) -> Tuple[str, str]: ) related_model._meta.add_field(backward_relation_name, fk_relation) + for field in model._meta.o2o_fields: + o2o_object = cast(fields.OneToOneField, model._meta.fields_map[field]) + reference = o2o_object.model_name + related_app_name, related_model_name = split_reference(reference) + related_model = get_related_model(related_app_name, related_model_name) + + key_field = f"{field}_id" + key_o2o_object = deepcopy(related_model._meta.pk) + key_o2o_object.pk = o2o_object.pk + key_o2o_object.index = o2o_object.index + key_o2o_object.default = o2o_object.default + key_o2o_object.null = o2o_object.null + key_o2o_object.unique = o2o_object.unique + key_o2o_object.generated = o2o_object.generated + key_o2o_object.reference = o2o_object + key_o2o_object.description = o2o_object.description + if o2o_object.source_field: + key_o2o_object.source_field = o2o_object.source_field + o2o_object.source_field = key_field + else: + o2o_object.source_field = key_field + key_o2o_object.source_field = key_field + model._meta.add_field(key_field, key_o2o_object) + + o2o_object.field_type = related_model + backward_relation_name = o2o_object.related_name + if not backward_relation_name: + backward_relation_name = f"{model._meta.table}" + if backward_relation_name in related_model._meta.fields: + raise ConfigurationError( + f'backward relation "{backward_relation_name}" duplicates in' + f" model {related_model_name}" + ) + o2o_relation = fields.BackwardOneToOneRelation( + model, f"{field}_id", null=True, description=o2o_object.description + ) + related_model._meta.add_field(backward_relation_name, o2o_relation) + + if o2o_object.pk: + pk_attr_changed = True + model._meta.pk_attr = key_field + for field in list(model._meta.m2m_fields): m2m_object = cast(fields.ManyToManyFieldInstance, model._meta.fields_map[field]) if m2m_object._generated: @@ -340,9 +403,7 @@ def split_reference(reference: str) -> Tuple[str, str]: backward_relation_name = m2m_object.related_name if not backward_relation_name: - backward_relation_name = ( - m2m_object.related_name - ) = f"{model._meta.table}_through" + backward_relation_name = m2m_object.related_name = f"{model._meta.table}s" if backward_relation_name in related_model._meta.fields: raise ConfigurationError( f'backward relation "{backward_relation_name}" duplicates in' @@ -371,6 +432,9 @@ def split_reference(reference: str) -> Tuple[str, str]: model._meta.filters.update(get_m2m_filters(field, m2m_object)) related_model._meta.add_field(backward_relation_name, m2m_relation) + if pk_attr_changed: + model._meta.finalise_pk() + @classmethod def _discover_client_class(cls, engine: str) -> BaseDBAsyncClient: # Let exception bubble up for transparency diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index a9d0b585e..c81911a5b 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -1,3 +1,4 @@ +import asyncio import datetime import decimal from copy import copy @@ -203,6 +204,29 @@ async def _prefetch_reverse_relation( relation_container._set_result_for_query(related_object_map.get(instance.pk, [])) return instance_list + async def _prefetch_reverse_o2o_relation( + self, instance_list: list, field: str, related_query + ) -> list: + instance_id_set: set = { + self._field_to_db(instance._meta.pk, instance.pk, instance) + for instance in instance_list + } + relation_field = self.model._meta.fields_map[field].relation_field # type: ignore + + related_object_list = await related_query.filter( + **{f"{relation_field}__in": list(instance_id_set)} + ) + + related_object_map: dict = {} + for entry in related_object_list: + object_id = getattr(entry, relation_field) + related_object_map[object_id] = entry + + for instance in instance_list: + setattr(instance, f"_{field}", related_object_map.get(instance.pk, None)) + + return instance_list + async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_query) -> list: instance_id_set: set = { self._field_to_db(instance._meta.pk, instance.pk, instance) @@ -318,6 +342,10 @@ def _make_prefetch_queries(self) -> None: async def _do_prefetch(self, instance_id_list: list, field: str, related_query) -> list: if field in self.model._meta.backward_fk_fields: return await self._prefetch_reverse_relation(instance_id_list, field, related_query) + + if field in self.model._meta.backward_o2o_fields: + return await self._prefetch_reverse_o2o_relation(instance_id_list, field, related_query) + if field in self.model._meta.m2m_fields: return await self._prefetch_m2m_relation(instance_id_list, field, related_query) return await self._prefetch_direct_relation(instance_id_list, field, related_query) @@ -325,8 +353,12 @@ async def _do_prefetch(self, instance_id_list: list, field: str, related_query) async def _execute_prefetch_queries(self, instance_list: list) -> list: if instance_list and (self.prefetch_map or self._prefetch_queries): self._make_prefetch_queries() - for field, related_query in self._prefetch_queries.items(): - await self._do_prefetch(instance_list, field, related_query) + prefetch_tasks = [ + self._do_prefetch(instance_list, field, related_query) + for field, related_query in self._prefetch_queries.items() + ] + await asyncio.gather(*prefetch_tasks) + return instance_list async def fetch_for_list(self, instance_list: list, *args) -> list: diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index 7c6afa64d..e0f83d680 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -4,6 +4,7 @@ from tortoise import fields from tortoise.exceptions import ConfigurationError +from tortoise.fields import OneToOneField from tortoise.utils import get_escape_translation_table # pylint: disable=R0201 @@ -187,11 +188,12 @@ def _get_table_sql(self, model, safe=True) -> dict: else "" ) # TODO: PK generation needs to move out of schema generator. - if field_object.pk: + if field_object.pk and not isinstance(field_object.reference, OneToOneField): pk_string = self._get_primary_key_create_string(field_object, db_field, comment) if pk_string: fields_to_create.append(pk_string) continue + nullable = "NOT NULL" if not field_object.null else "" unique = "UNIQUE" if field_object.unique else "" diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 7d034d155..d03f7ae85 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -44,6 +44,7 @@ def __init__(self, file_path: str, **kwargs) -> None: self.pragmas = kwargs.copy() self.pragmas.pop("connection_name", None) self.pragmas.pop("fetch_inserted", None) + self.pragmas["foreign_keys"] = "ON" self._connection: Optional[aiosqlite.Connection] = None self._lock = asyncio.Lock() diff --git a/tortoise/fields.py b/tortoise/fields.py index 6f1e6ecfd..f16793f75 100644 --- a/tortoise/fields.py +++ b/tortoise/fields.py @@ -478,6 +478,78 @@ def __await__(self): # pragma: nocoverage ... # pylint: disable=W0104 +OneToOneNullableRelation = Union[Awaitable[Optional[MODEL]], Optional[MODEL]] +""" +Type hint for the result of accessing the :class:`.OneToOneField` field in the model +when obtained model can be nullable. +""" + +OneToOneRelation = Union[Awaitable[MODEL], MODEL] +""" +Type hint for the result of accessing the :class:`.OneToOneField` field in the model. +""" + + +class OneToOneField(Field): + """ + OneToOne relation field. + + This field represents a one to one relation to another model. + + You must provide the following: + + ``model_name``: + The name of the related model in a :samp:`'{app}.{model}'` format. + + The following is optional: + + ``related_name``: + The attribute name on the related model to reverse resolve the one to one relation. + ``on_delete``: + One of: + ``field.CASCADE``: + Indicate that the model should be cascade deleted if related model gets deleted. + ``field.RESTRICT``: + Indicate that the related model delete will be restricted as long as a + one to one relation points to it. + ``field.SET_NULL``: + Resets the field to NULL in case the related model gets deleted. + Can only be set if field has ``null=True`` set. + ``field.SET_DEFAULT``: + Resets the field to ``default`` value in case the related model gets deleted. + Can only be set is field has a ``default`` set. + """ + + __slots__ = ( + "field_type", + # type will be set later, so we need a slot to be able to write it + "model_name", + "related_name", + "on_delete", + ) + has_db_field = False + + def __init__( + self, model_name: str, related_name: Optional[str] = None, on_delete=CASCADE, **kwargs + ) -> None: + kwargs["unique"] = True + super().__init__(**kwargs) + # self.field_type: "Type[Model]" = None # type: ignore + if len(model_name.split(".")) != 2: + raise ConfigurationError('OneToOneField accepts model name in format "app.Model"') + self.model_name = model_name + self.related_name = related_name + if on_delete not in {CASCADE, RESTRICT, SET_NULL}: + raise ConfigurationError("on_delete can only be CASCADE, RESTRICT or SET_NULL") + if on_delete == SET_NULL and not bool(kwargs.get("null")): + raise ConfigurationError("If on_delete is SET_NULL, then field must have null=True set") + self.on_delete = on_delete + + # we need this for IDEs so that they don't say that the field is not awaitable + def __await__(self): + ... # pylint: disable=W0104 + + class ManyToManyFieldInstance(Field): __slots__ = ( "field_type", # Here we need type to be able to set dyamically @@ -565,6 +637,10 @@ def __init__( self.description = description +class BackwardOneToOneRelation(BackwardFKRelation): + pass + + class ReverseRelation(Generic[MODEL]): """ Relation container for :class:`.ForeignKeyField`. diff --git a/tortoise/models.py b/tortoise/models.py index 85ea2715f..817df530c 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -47,6 +47,15 @@ def _rfk_getter(self, _key, ftype, frelfield): return val +def _ro2o_getter(self, _key, ftype, frelfield): + if hasattr(self, _key): + return getattr(self, _key) + + val = ftype.filter(**{frelfield: self.pk}).first() + setattr(self, _key, val) + return val + + def _m2m_getter(self, _key, field_object): val = getattr(self, _key, None) if val is None: @@ -63,6 +72,8 @@ class MetaInfo: "fields", "db_fields", "m2m_fields", + "o2o_fields", + "backward_o2o_fields", "fk_fields", "backward_fk_fields", "fetch_fields", @@ -99,7 +110,9 @@ def __init__(self, meta) -> None: self.db_fields: Set[str] = set() self.m2m_fields: Set[str] = set() self.fk_fields: Set[str] = set() + self.o2o_fields: Set[str] = set() self.backward_fk_fields: Set[str] = set() + self.backward_o2o_fields: Set[str] = set() self.fetch_fields: Set[str] = set() self.fields_db_projection: Dict[str, str] = {} self.fields_db_projection_reverse: Dict[str, str] = {} @@ -132,6 +145,8 @@ def add_field(self, name: str, value: Field): if isinstance(value, fields.ManyToManyFieldInstance): self.m2m_fields.add(name) + elif isinstance(value, fields.BackwardOneToOneRelation): + self.backward_o2o_fields.add(name) elif isinstance(value, fields.BackwardFKRelation): self.backward_fk_fields.add(name) @@ -170,7 +185,13 @@ def finalise_fields(self) -> None: self.fields_db_projection_reverse = { value: key for key, value in self.fields_db_projection.items() } - self.fetch_fields = self.m2m_fields | self.backward_fk_fields | self.fk_fields + self.fetch_fields = ( + self.m2m_fields + | self.backward_fk_fields + | self.fk_fields + | self.backward_o2o_fields + | self.o2o_fields + ) generated_fields = [] for field in self.fields_map.values(): @@ -216,6 +237,42 @@ def _generate_lazy_fk_m2m_fields(self) -> None: ), ) + # Create lazy one to one fields on model. + for key in self.o2o_fields: + _key = f"_{key}" + relation_field = self.fields_map[key].source_field + setattr( + self._model, + key, + property( + partial( + _fk_getter, + _key=_key, + ftype=self.fields_map[key].field_type, + relation_field=relation_field, + ), + partial(_fk_setter, _key=_key, relation_field=relation_field), + partial(_fk_setter, value=None, _key=_key, relation_field=relation_field), + ), + ) + + # Create lazy reverse one to one fields on model. + for key in self.backward_o2o_fields: + _key = f"_{key}" + field_object: fields.BackwardOneToOneRelation = self.fields_map[key] # type: ignore + setattr( + self._model, + key, + property( + partial( + _ro2o_getter, + _key=_key, + ftype=field_object.field_type, + frelfield=field_object.relation_field, + ), + ), + ) + # Create lazy M2M fields on model. for key in self.m2m_fields: _key = f"_{key}" @@ -268,6 +325,7 @@ def __new__(mcs, name: str, bases, attrs: dict, *args, **kwargs): filters: Dict[str, Dict[str, dict]] = {} fk_fields: Set[str] = set() m2m_fields: Set[str] = set() + o2o_fields: Set[str] = set() meta_class = attrs.get("Meta", type("Meta", (), {})) pk_attr: str = "id" @@ -350,6 +408,8 @@ def __search_for_field_attributes(base, attrs: dict): if isinstance(value, fields.ForeignKeyField): fk_fields.add(key) + elif isinstance(value, fields.OneToOneField): + o2o_fields.add(key) elif isinstance(value, fields.ManyToManyFieldInstance): m2m_fields.add(key) else: @@ -380,6 +440,8 @@ def __search_for_field_attributes(base, attrs: dict): meta._filters = filters meta.fk_fields = fk_fields meta.backward_fk_fields = set() + meta.o2o_fields = o2o_fields + meta.backward_o2o_fields = set() meta.m2m_fields = m2m_fields meta.default_connection = None meta.pk_attr = pk_attr @@ -409,7 +471,7 @@ def __init__(self, **kwargs) -> None: passed_fields = {*kwargs.keys()} | meta.fetch_fields for key, value in kwargs.items(): - if key in meta.fk_fields: + if key in meta.fk_fields or key in meta.o2o_fields: if value and not value._saved_in_db: raise OperationalError( f"You should first call .save() on {value} before referring to it" @@ -427,6 +489,11 @@ def __init__(self, **kwargs) -> None: raise ConfigurationError( "You can't set backward relations through init, change related model instead" ) + elif key in meta.backward_o2o_fields: + raise ConfigurationError( + "You can't set backward one to one relations through init," + " change related model instead" + ) elif key in meta.m2m_fields: raise ConfigurationError( "You can't set m2m relations through init, use m2m_manager instead" diff --git a/tortoise/query_utils.py b/tortoise/query_utils.py index 6f0ab7e26..89bceb573 100644 --- a/tortoise/query_utils.py +++ b/tortoise/query_utils.py @@ -250,7 +250,7 @@ def _resolve_regular_kwarg(self, model, key, value) -> QueryModifier: return modifier def _get_actual_filter_params(self, model, key, value) -> Tuple[str, Any]: - if key in model._meta.fk_fields: + if key in model._meta.fk_fields or key in model._meta.o2o_fields: field_object = model._meta.fields_map[key] if hasattr(value, "pk"): filter_value = value.pk diff --git a/tortoise/queryset.py b/tortoise/queryset.py index f90ed4159..38a2b298a 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -562,7 +562,7 @@ def _make_query(self) -> None: raise FieldError(f"Unknown keyword argument {key} for model {self.model}") if field_object.pk: raise IntegrityError(f"Field {key} is PK and can not be updated") - if isinstance(field_object, fields.ForeignKeyField): + if isinstance(field_object, (fields.ForeignKeyField, fields.OneToOneField)): fk_field: str = field_object.source_field # type: ignore db_field = self.model._meta.fields_map[fk_field].source_field value = executor.column_map[fk_field](value.pk, None)