Skip to content

Commit 8cfab26

Browse files
authored
fix(RFC): Use metaclass for safe DType attr access (#2025)
* fix(RFC): Use metaclass for safe `DType` attr access Mentioned in -#1991 (comment) - #1807 (comment) * chore: add `_DurationMeta` Both `Duration` and `Datetime` are working with `polars` now. From this point it should just be reducing code for all the other backends * refactor: upgrade `_pandas` * refactor: upgrade `_arrow` * refactor: "upgrade" `_duckdb` They're all noops, but good to keep consistent * refactor: upgrade `_spark_like` * chore: remove comment moved to https://github.com/narwhals-dev/narwhals/pull/2025/files#r1958596925 * refactor: simplify `__eq__` The metaclass is much narrower than `type` previously * fix: maybe fix typo "dt_time_unit" Fixes #2025 (comment)
1 parent 663995d commit 8cfab26

File tree

6 files changed

+41
-35
lines changed

6 files changed

+41
-35
lines changed

narwhals/_arrow/utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from narwhals._arrow.typing import Incomplete
3030
from narwhals._arrow.typing import StringArray
3131
from narwhals.dtypes import DType
32-
from narwhals.typing import TimeUnit
3332
from narwhals.typing import _AnyDArray
3433
from narwhals.utils import Version
3534

@@ -182,12 +181,9 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa
182181
if isinstance_or_issubclass(dtype, dtypes.Categorical):
183182
return pa.dictionary(pa.uint32(), pa.string())
184183
if isinstance_or_issubclass(dtype, dtypes.Datetime):
185-
time_unit: TimeUnit = getattr(dtype, "time_unit", "us")
186-
time_zone = getattr(dtype, "time_zone", None)
187-
return pa.timestamp(time_unit, tz=time_zone)
184+
return pa.timestamp(dtype.time_unit, tz=dtype.time_zone) # type: ignore[arg-type]
188185
if isinstance_or_issubclass(dtype, dtypes.Duration):
189-
time_unit = getattr(dtype, "time_unit", "us")
190-
return pa.duration(time_unit)
186+
return pa.duration(dtype.time_unit)
191187
if isinstance_or_issubclass(dtype, dtypes.Date):
192188
return pa.date32()
193189
if isinstance_or_issubclass(dtype, dtypes.List):

narwhals/_duckdb/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,12 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st
176176
msg = "Categorical not supported by DuckDB"
177177
raise NotImplementedError(msg)
178178
if isinstance_or_issubclass(dtype, dtypes.Datetime):
179-
_time_unit = getattr(dtype, "time_unit", "us")
180-
_time_zone = getattr(dtype, "time_zone", None)
179+
_time_unit = dtype.time_unit
180+
_time_zone = dtype.time_zone
181181
msg = "todo"
182182
raise NotImplementedError(msg)
183183
if isinstance_or_issubclass(dtype, dtypes.Duration): # pragma: no cover
184-
_time_unit = getattr(dtype, "time_unit", "us")
184+
_time_unit = dtype.time_unit
185185
msg = "todo"
186186
raise NotImplementedError(msg)
187187
if isinstance_or_issubclass(dtype, dtypes.Date): # pragma: no cover

narwhals/_pandas_like/utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -613,28 +613,27 @@ def narwhals_to_native_dtype( # noqa: PLR0915
613613
# convert to it?
614614
return "category"
615615
if isinstance_or_issubclass(dtype, dtypes.Datetime):
616-
dt_time_unit = getattr(dtype, "time_unit", "us")
617-
dt_time_zone = getattr(dtype, "time_zone", None)
618-
619616
# Pandas does not support "ms" or "us" time units before version 2.0
620-
# Let's overwrite with "ns"
621617
if implementation is Implementation.PANDAS and backend_version < (
622618
2,
623619
): # pragma: no cover
624620
dt_time_unit = "ns"
621+
else:
622+
dt_time_unit = dtype.time_unit
625623

626624
if dtype_backend == "pyarrow":
627-
tz_part = f", tz={dt_time_zone}" if dt_time_zone else ""
625+
tz_part = f", tz={tz}" if (tz := dtype.time_zone) else ""
628626
return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]"
629627
else:
630-
tz_part = f", {dt_time_zone}" if dt_time_zone else ""
628+
tz_part = f", {tz}" if (tz := dtype.time_zone) else ""
631629
return f"datetime64[{dt_time_unit}{tz_part}]"
632630
if isinstance_or_issubclass(dtype, dtypes.Duration):
633-
du_time_unit = getattr(dtype, "time_unit", "us")
634631
if implementation is Implementation.PANDAS and backend_version < (
635632
2,
636633
): # pragma: no cover
637-
dt_time_unit = "ns"
634+
du_time_unit = "ns"
635+
else:
636+
du_time_unit = dtype.time_unit
638637
return (
639638
f"duration[{du_time_unit}][pyarrow]"
640639
if dtype_backend == "pyarrow"

narwhals/_polars/utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from narwhals.exceptions import NarwhalsError
1515
from narwhals.exceptions import ShapeError
1616
from narwhals.utils import import_dtypes_module
17+
from narwhals.utils import isinstance_or_issubclass
1718

1819
if TYPE_CHECKING:
1920
from narwhals._polars.dataframe import PolarsDataFrame
@@ -190,13 +191,10 @@ def narwhals_to_native_dtype(
190191
if dtype == dtypes.Decimal:
191192
msg = "Casting to Decimal is not supported yet."
192193
raise NotImplementedError(msg)
193-
if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime):
194-
dt_time_unit: TimeUnit = getattr(dtype, "time_unit", "us")
195-
dt_time_zone = getattr(dtype, "time_zone", None)
196-
return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type]
197-
if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration):
198-
du_time_unit: TimeUnit = getattr(dtype, "time_unit", "us")
199-
return pl.Duration(time_unit=du_time_unit) # type: ignore[arg-type]
194+
if isinstance_or_issubclass(dtype, dtypes.Datetime):
195+
return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type]
196+
if isinstance_or_issubclass(dtype, dtypes.Duration):
197+
return pl.Duration(dtype.time_unit) # type: ignore[arg-type]
200198
if dtype == dtypes.List:
201199
return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) # type: ignore[union-attr]
202200
if dtype == dtypes.Struct:

narwhals/_spark_like/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def narwhals_to_native_dtype(
116116
if isinstance_or_issubclass(dtype, dtypes.Date):
117117
return spark_types.DateType()
118118
if isinstance_or_issubclass(dtype, dtypes.Datetime):
119-
dt_time_zone = getattr(dtype, "time_zone", None)
119+
dt_time_zone = dtype.time_zone
120120
if dt_time_zone is None:
121121
return spark_types.TimestampNTZType()
122122
if dt_time_zone != "UTC": # pragma: no cover

narwhals/dtypes.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,17 @@ class Unknown(DType):
448448
"""
449449

450450

451-
class Datetime(TemporalType):
451+
class _DatetimeMeta(type):
452+
@property
453+
def time_unit(cls) -> TimeUnit:
454+
return "us"
455+
456+
@property
457+
def time_zone(cls) -> str | None:
458+
return None
459+
460+
461+
class Datetime(TemporalType, metaclass=_DatetimeMeta):
452462
"""Data type representing a calendar date and time of day.
453463
454464
Arguments:
@@ -505,11 +515,11 @@ def __init__(
505515
time_zone = str(time_zone)
506516

507517
self.time_unit: TimeUnit = time_unit
508-
self.time_zone = time_zone
518+
self.time_zone: str | None = time_zone
509519

510520
def __eq__(self: Self, other: object) -> bool:
511521
# allow comparing object instances to class
512-
if type(other) is type and issubclass(other, self.__class__):
522+
if type(other) is _DatetimeMeta:
513523
return True
514524
elif isinstance(other, self.__class__):
515525
return self.time_unit == other.time_unit and self.time_zone == other.time_zone
@@ -524,7 +534,13 @@ def __repr__(self: Self) -> str: # pragma: no cover
524534
return f"{class_name}(time_unit={self.time_unit!r}, time_zone={self.time_zone!r})"
525535

526536

527-
class Duration(TemporalType):
537+
class _DurationMeta(type):
538+
@property
539+
def time_unit(cls) -> TimeUnit:
540+
return "us"
541+
542+
543+
class Duration(TemporalType, metaclass=_DurationMeta):
528544
"""Data type representing a time duration.
529545
530546
Arguments:
@@ -552,22 +568,19 @@ class Duration(TemporalType):
552568
Duration(time_unit='ms')
553569
"""
554570

555-
def __init__(
556-
self: Self,
557-
time_unit: TimeUnit = "us",
558-
) -> None:
571+
def __init__(self: Self, time_unit: TimeUnit = "us") -> None:
559572
if time_unit not in ("s", "ms", "us", "ns"):
560573
msg = (
561574
"invalid `time_unit`"
562575
f"\n\nExpected one of {{'ns','us','ms', 's'}}, got {time_unit!r}."
563576
)
564577
raise ValueError(msg)
565578

566-
self.time_unit = time_unit
579+
self.time_unit: TimeUnit = time_unit
567580

568581
def __eq__(self: Self, other: object) -> bool:
569582
# allow comparing object instances to class
570-
if type(other) is type and issubclass(other, self.__class__):
583+
if type(other) is _DurationMeta:
571584
return True
572585
elif isinstance(other, self.__class__):
573586
return self.time_unit == other.time_unit

0 commit comments

Comments
 (0)