Skip to content

Commit 01c393f

Browse files
wallneradamgrigi
authored andcommitted
EnumField and CharEnumField (tortoise#244)
1 parent f3ea664 commit 01c393f

File tree

6 files changed

+186
-0
lines changed

6 files changed

+186
-0
lines changed

docs/examples/basic.rst

+9
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,12 @@ Recursive Relations
9393
===================
9494
.. literalinclude:: ../../examples/relations_recursive.py
9595

96+
97+
.. rst-class:: html-toggle
98+
99+
.. _example_enum_fields:
100+
101+
===============
102+
Enumeration Fields
103+
===============
104+
.. literalinclude:: ../../examples/enum_fields.py

docs/fields.rst

+6
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ Data Fields
7070
.. autoclass:: tortoise.fields.UUIDField
7171
:exclude-members: to_db_value, to_python_value
7272

73+
.. autoclass:: tortoise.fields.IntEnumField
74+
:exclude-members: to_db_value, to_python_value
75+
76+
.. autoclass:: tortoise.fields.CharEnumField
77+
:exclude-members: to_db_value, to_python_value
78+
7379

7480
Relational Fields
7581
-----------------

examples/enum_fields.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import typing
2+
from enum import Enum, IntEnum
3+
4+
from tortoise import fields
5+
from tortoise.models import Model
6+
7+
8+
class Service(IntEnum):
9+
python_programming = 1
10+
database_design = 2
11+
system_administration = 3
12+
13+
14+
class Currency(str, Enum):
15+
HUF = "HUF"
16+
EUR = "EUR"
17+
USD = "USD"
18+
19+
20+
class EnumFields(Model):
21+
service: Service = typing.cast(Service, fields.IntEnumField(Service))
22+
currency: Currency = typing.cast(Currency, fields.CharEnumField(Currency, default=Currency.HUF))

tests/fields/test_enum.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from tests import testmodels
2+
from tortoise.contrib import test
3+
from tortoise.exceptions import IntegrityError
4+
5+
6+
class TestEnumFields(test.TestCase):
7+
async def test_empty(self):
8+
with self.assertRaises(IntegrityError):
9+
await testmodels.EnumFields.create()
10+
11+
async def test_create(self):
12+
obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration)
13+
self.assertIsInstance(obj0.service, testmodels.Service)
14+
self.assertIsInstance(obj0.currency, testmodels.Currency)
15+
obj = await testmodels.EnumFields.get(id=obj0.id)
16+
self.assertIsInstance(obj.service, testmodels.Service)
17+
self.assertIsInstance(obj.currency, testmodels.Currency)
18+
self.assertEqual(obj.service, testmodels.Service.system_administration)
19+
self.assertEqual(obj.currency, testmodels.Currency.HUF)
20+
await obj.save()
21+
obj2 = await testmodels.EnumFields.get(id=obj.id)
22+
self.assertEqual(obj, obj2)
23+
24+
await obj.delete()
25+
obj = await testmodels.EnumFields.filter(id=obj0.id).first()
26+
self.assertEqual(obj, None)
27+
28+
async def test_update(self):
29+
obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration)
30+
await testmodels.EnumFields.filter(id=obj0.id).update(
31+
service=testmodels.Service.database_design, currency=testmodels.Currency.EUR
32+
)
33+
obj = await testmodels.EnumFields.get(id=obj0.id)
34+
self.assertEqual(obj.service, testmodels.Service.database_design)
35+
self.assertEqual(obj.currency, testmodels.Currency.EUR)

tests/testmodels.py

+19
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
"""
44
import binascii
55
import os
6+
import typing
67
import uuid
8+
from enum import Enum, IntEnum
79

810
from tortoise import fields
911
from tortoise.models import Model
@@ -481,3 +483,20 @@ class Meta:
481483
table = "sometable"
482484
unique_together = [["chars", "blip"]]
483485
table_description = "Source mapped fields"
486+
487+
488+
class Service(IntEnum):
489+
python_programming = 1
490+
database_design = 2
491+
system_administration = 3
492+
493+
494+
class Currency(str, Enum):
495+
HUF = "HUF"
496+
EUR = "EUR"
497+
USD = "USD"
498+
499+
500+
class EnumFields(Model):
501+
service: Service = typing.cast(Service, fields.IntEnumField(Service))
502+
currency: Currency = typing.cast(Currency, fields.CharEnumField(Currency, default=Currency.HUF))

tortoise/fields.py

+95
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import functools
33
import json
44
from decimal import Decimal
5+
from enum import Enum, IntEnum
56
from typing import TYPE_CHECKING, Any, Awaitable, Generic, Optional, Type, TypeVar, Union
67
from uuid import UUID, uuid4
78

@@ -407,6 +408,100 @@ def to_python_value(self, value: Any) -> Optional[UUID]:
407408
return UUID(value)
408409

409410

411+
class IntEnumField(SmallIntField):
412+
"""
413+
Enum Field
414+
415+
A field representing an integer enumeration.
416+
417+
The description of the field is set automatically if not specified to a multiline list of
418+
"name: value" pairs.
419+
420+
``enum_type``:
421+
The enum class
422+
``description``:
423+
The description of the field. It is set automatically if not specified to a multiline list
424+
of "name: value" pairs.
425+
"""
426+
427+
__slots__ = ("enum_type",)
428+
429+
def __init__(
430+
self, enum_type: Type[IntEnum], description: Optional[str] = None, **kwargs
431+
) -> None:
432+
# Validate values
433+
for item in enum_type:
434+
try:
435+
value = int(item.value)
436+
except ValueError:
437+
raise ConfigurationError("IntEnumField only supports integer enums!")
438+
if not 0 <= value < 32768:
439+
raise ConfigurationError("The valid range of IntEnumField's values is 0..32767!")
440+
441+
# Automatic description for the field if not specified by the user
442+
if description is None:
443+
description = "\n".join([f"{e.name}: {int(e.value)}" for e in enum_type])[:2048]
444+
445+
super().__init__(description=description, **kwargs)
446+
self.enum_type = enum_type
447+
448+
def to_python_value(self, value: Union[int, None]) -> Union[IntEnum, None]:
449+
return self.enum_type(value) if value is not None else None
450+
451+
def to_db_value(self, value: Union[IntEnum, None], instance) -> Union[int, None]:
452+
return int(value.value) if value is not None else None
453+
454+
455+
class CharEnumField(CharField):
456+
"""
457+
Char Enum Field
458+
459+
A field representing a character enumeration.
460+
461+
**Warning**: If ``max_length`` is not specified or equals to zero, the size of represented
462+
char fields is automatically detected. So if later you update the enum, you need to update your
463+
table schema as well.
464+
465+
``enum_type``:
466+
The enum class
467+
``description``:
468+
The description of the field. It is set automatically if not specified to a multiline list
469+
of "name: value" pairs.
470+
``max_length``:
471+
The length of the created CharField. If it is zero it is automatically detected from
472+
enum_type.
473+
"""
474+
475+
__slots__ = ("enum_type",)
476+
477+
def __init__(
478+
self,
479+
enum_type: Type[Enum],
480+
description: Optional[str] = None,
481+
max_length: int = 0,
482+
**kwargs,
483+
) -> None:
484+
# Automatic description for the field if not specified by the user
485+
if description is None:
486+
description = "\n".join([f"{e.name}: {str(e.value)}" for e in enum_type])[:2048]
487+
488+
# Automatic CharField max_length
489+
if max_length == 0:
490+
for item in enum_type:
491+
item_len = len(str(item.value))
492+
if item_len > max_length:
493+
max_length = item_len
494+
495+
super().__init__(description=description, max_length=max_length, **kwargs)
496+
self.enum_type = enum_type
497+
498+
def to_python_value(self, value: Union[str, None]) -> Union[Enum, None]:
499+
return self.enum_type(value) if value is not None else None
500+
501+
def to_db_value(self, value: Union[Enum, None], instance) -> Union[str, None]:
502+
return str(value.value) if value is not None else None
503+
504+
410505
ForeignKeyNullableRelation = Union[Awaitable[Optional[MODEL]], Optional[MODEL]]
411506
"""
412507
Type hint for the result of accessing the :class:`.ForeignKeyField` field in the model

0 commit comments

Comments
 (0)