Skip to content

Commit b239832

Browse files
authored
fix: Create FlexibleEnumMeta enum metaclass (#74)
1 parent bfbd5e8 commit b239832

File tree

10 files changed

+102
-17
lines changed

10 files changed

+102
-17
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
kind: Fixes
2+
body: Added `FlexibleEnumMeta` to make enums non-breaking when we add a new value to the API.
3+
time: 2025-03-27T18:41:48.001433+01:00

dbtsl/models/base.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import warnings
33
from dataclasses import dataclass, fields, is_dataclass
44
from dataclasses import field as dc_field
5+
from enum import EnumMeta
56
from functools import cache
67
from types import MappingProxyType
7-
from typing import Any, ClassVar, Dict, List, Set, Type, Union
8+
from typing import Any, ClassVar, Dict, List, Set, Tuple, Type, Union
89
from typing import get_args as get_type_args
910
from typing import get_origin as get_type_origin
1011

@@ -20,6 +21,35 @@ def snake_case_to_camel_case(s: str) -> str:
2021
return tokens[0] + "".join(t.title() for t in tokens[1:])
2122

2223

24+
class FlexibleEnumMeta(EnumMeta):
25+
"""Makes an Enum class not break if you provide it an unknown value."""
26+
27+
_subclass_registry: ClassVar[Set[str]] = set()
28+
29+
UNKNOWN = "UNKNOWN"
30+
31+
def __new__(metacls: Type["FlexibleEnumMeta"], name: str, bases: Tuple[Type], namespace: Dict[str, Any], **kwargs):
32+
"""Overwrite the _missing_ method of enum classes."""
33+
msg = f"Class {name} needs UNKNOWN attribute with 'UNKNOWN' string value"
34+
assert namespace.get("UNKNOWN", None) == "UNKNOWN", msg
35+
36+
metacls._subclass_registry.add(name)
37+
38+
newclass = super().__new__(metacls, name, bases, namespace) # pyright: ignore[reportArgumentType]
39+
setattr(newclass, "_missing_", classmethod(metacls._missing_)) # pyright: ignore[reportArgumentType]
40+
return newclass
41+
42+
def __getitem__(cls, name: str) -> Any:
43+
"""Return the UNKNOWN attribute if can't find value in class."""
44+
try:
45+
return super().__getitem__(name)
46+
except KeyError:
47+
return cls.UNKNOWN
48+
49+
def _missing_(cls, _name: str) -> str:
50+
return cls.UNKNOWN
51+
52+
2353
class BaseModel(DataClassDictMixin):
2454
"""Base class for all serializable models.
2555

dbtsl/models/dimension.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from enum import Enum
33
from typing import List, Optional
44

5-
from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
5+
from dbtsl.models.base import BaseModel, FlexibleEnumMeta, GraphQLFragmentMixin
66
from dbtsl.models.time import TimeGranularity
77

88

9-
class DimensionType(str, Enum):
9+
class DimensionType(Enum, metaclass=FlexibleEnumMeta):
1010
"""The type of a dimension."""
1111

12+
UNKNOWN = "UNKNOWN"
1213
CATEGORICAL = "CATEGORICAL"
1314
TIME = "TIME"
1415

dbtsl/models/entity.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from enum import Enum
33
from typing import Optional
44

5-
from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
5+
from dbtsl.models.base import BaseModel, FlexibleEnumMeta, GraphQLFragmentMixin
66

77

8-
class EntityType(str, Enum):
8+
class EntityType(Enum, metaclass=FlexibleEnumMeta):
99
"""All supported entity types."""
1010

11+
UNKNOWN = "UNKNOWN"
1112
FOREIGN = "FOREIGN"
1213
NATURAL = "NATURAL"
1314
PRIMARY = "PRIMARY"

dbtsl/models/measure.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from enum import Enum
33
from typing import Optional
44

5-
from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
5+
from dbtsl.models.base import BaseModel, FlexibleEnumMeta, GraphQLFragmentMixin
66

77

8-
class AggregationType(str, Enum):
8+
class AggregationType(Enum, metaclass=FlexibleEnumMeta):
99
"""All supported aggregation functions."""
1010

11+
UNKNOWN = "UNKNOWN"
1112
SUM = "SUM"
1213
MIN = "MIN"
1314
MAX = "MAX"

dbtsl/models/metric.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22
from enum import Enum
33
from typing import List, Optional
44

5-
from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
5+
from dbtsl.models.base import BaseModel, FlexibleEnumMeta, GraphQLFragmentMixin
66
from dbtsl.models.dimension import Dimension
77
from dbtsl.models.entity import Entity
88
from dbtsl.models.measure import Measure
99
from dbtsl.models.time import TimeGranularity
1010

1111

12-
class MetricType(str, Enum):
12+
class MetricType(Enum, metaclass=FlexibleEnumMeta):
1313
"""The type of a Metric."""
1414

15+
UNKNOWN = "UNKNOWN"
1516
SIMPLE = "SIMPLE"
1617
RATIO = "RATIO"
1718
CUMULATIVE = "CUMULATIVE"

dbtsl/models/query.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66

77
import pyarrow as pa
88

9-
from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
9+
from dbtsl.models.base import BaseModel, FlexibleEnumMeta, GraphQLFragmentMixin
1010

1111
QueryId = NewType("QueryId", str)
1212

1313

14-
class QueryStatus(str, Enum):
14+
class QueryStatus(Enum, metaclass=FlexibleEnumMeta):
1515
"""All the possible states of a query."""
1616

17+
UNKNOWN = "UNKNOWN"
1718
PENDING = "PENDING"
1819
RUNNING = "RUNNING"
1920
COMPILED = "COMPILED"

dbtsl/models/saved_query.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from enum import Enum
44
from typing import List, Optional
55

6-
from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
6+
from dbtsl.models.base import BaseModel, FlexibleEnumMeta, GraphQLFragmentMixin
77
from dbtsl.models.time import DatePart, TimeGranularity
88

99

10-
class ExportDestinationType(str, Enum):
10+
class ExportDestinationType(Enum, metaclass=FlexibleEnumMeta):
1111
"""All kinds of export destinations."""
1212

13+
UNKNOWN = "UNKNOWN"
1314
TABLE = "TABLE"
1415
VIEW = "VIEW"
1516

dbtsl/models/time.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from typing_extensions import override
44

5-
from dbtsl.models.base import DeprecatedMixin
5+
from dbtsl.models.base import DeprecatedMixin, FlexibleEnumMeta
66

77

8-
class TimeGranularity(str, DeprecatedMixin, Enum):
8+
class TimeGranularity(DeprecatedMixin, Enum, metaclass=FlexibleEnumMeta):
99
"""A time granularity."""
1010

1111
@override
@@ -16,6 +16,7 @@ def _deprecation_message(cls) -> str:
1616
"Please just use strings to represent time grains."
1717
)
1818

19+
UNKNOWN = "UNKNOWN"
1920
NANOSECOND = "NANOSECOND"
2021
MICROSECOND = "MICROSECOND"
2122
MILLISECOND = "MILLISECOND"
@@ -29,9 +30,10 @@ def _deprecation_message(cls) -> str:
2930
YEAR = "YEAR"
3031

3132

32-
class DatePart(str, Enum):
33+
class DatePart(Enum, metaclass=FlexibleEnumMeta):
3334
"""Date part."""
3435

36+
UNKNOWN = "UNKNOWN"
3537
DOY = "DOY"
3638
DOW = "DOW"
3739
DAY = "DAY"

tests/test_models.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import inspect
12
import warnings
23
from dataclasses import dataclass
34
from dataclasses import field as dc_field
5+
from enum import Enum
46
from typing import List
57

68
import pytest
79
from mashumaro.codecs.basic import decode
810
from typing_extensions import override
911

12+
import dbtsl.models as ALL_EXPORTED_MODELS
1013
from dbtsl.api.graphql.util import normalize_query
1114
from dbtsl.api.shared.query_params import (
1215
AdhocQueryParametersStrict,
@@ -17,7 +20,7 @@
1720
validate_order_by,
1821
validate_query_parameters,
1922
)
20-
from dbtsl.models.base import BaseModel, DeprecatedMixin, GraphQLFragmentMixin
23+
from dbtsl.models.base import BaseModel, DeprecatedMixin, FlexibleEnumMeta, GraphQLFragmentMixin
2124
from dbtsl.models.base import snake_case_to_camel_case as stc
2225

2326

@@ -29,6 +32,47 @@ def test_snake_case_to_camel_case() -> None:
2932
assert stc("helloWorld") == "helloWorld"
3033

3134

35+
def test_FlexibleEnumMeta_parse_unknown_value() -> None:
36+
"""Make sure FlexibleEnumMeta classes parse unknown values without error."""
37+
38+
class EnumTest(Enum, metaclass=FlexibleEnumMeta):
39+
A = "A"
40+
B = "B"
41+
UNKNOWN = "UNKNOWN"
42+
43+
assert EnumTest("A") == EnumTest.A
44+
assert EnumTest("B") == EnumTest.B
45+
assert EnumTest("test") == EnumTest.UNKNOWN
46+
47+
48+
def test_FlexibleEnumMeta_subclass_with_invalid_unknown_attribute() -> None:
49+
"""Make sure we'll raise an error whenever a flexible enum isn't declared properly."""
50+
with pytest.raises(AssertionError):
51+
52+
class EnumTestNoUnknown(Enum, metaclass=FlexibleEnumMeta):
53+
A = "A"
54+
55+
_ = EnumTestNoUnknown
56+
57+
with pytest.raises(AssertionError):
58+
59+
class EnumTestInvalidUnknown(Enum, metaclass=FlexibleEnumMeta):
60+
A = "A"
61+
UNKNOWN = "invalid_value"
62+
63+
_ = EnumTestInvalidUnknown
64+
65+
66+
def test_all_enum_models_are_flexible() -> None:
67+
"""Make sure we didn't forget to make any enum type flexible."""
68+
exported_enum_classes = inspect.getmembers(
69+
ALL_EXPORTED_MODELS, lambda member: (inspect.isclass(member) and issubclass(member, Enum))
70+
)
71+
for enum_class_name, _ in exported_enum_classes:
72+
msg = f"Enum {enum_class_name} needs to have FlexibleEnumMeta metaclass."
73+
assert enum_class_name in FlexibleEnumMeta._subclass_registry, msg
74+
75+
3276
def test_base_model_auto_alias() -> None:
3377
@dataclass
3478
class SubModel(BaseModel):

0 commit comments

Comments
 (0)