Skip to content

Commit 97fb90b

Browse files
authored
(fix): disallow NumpyExtensionArray (#10334)
1 parent 60bc816 commit 97fb90b

File tree

5 files changed

+66
-10
lines changed

5 files changed

+66
-10
lines changed

properties/test_pandas_roundtrip.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import hypothesis.extra.pandas as pdst # isort:skip
1616
import hypothesis.strategies as st # isort:skip
1717
from hypothesis import given # isort:skip
18+
from xarray.tests import has_pyarrow
1819

1920
numeric_dtypes = st.one_of(
2021
npst.unsigned_integer_dtypes(endianness="="),
@@ -134,10 +135,39 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
134135
xr.testing.assert_identical(dataset, roundtripped.to_xarray())
135136

136137

137-
def test_roundtrip_1d_pandas_extension_array() -> None:
138-
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
139-
arr = xr.Dataset.from_dataframe(df)["cat"]
138+
@pytest.mark.parametrize(
139+
"extension_array",
140+
[
141+
pd.Categorical(["a", "b", "c"]),
142+
pd.array(["a", "b", "c"], dtype="string"),
143+
pd.arrays.IntervalArray(
144+
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
145+
),
146+
pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])),
147+
pd.arrays.DatetimeArray._from_sequence(
148+
pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D")
149+
),
150+
np.array([1, 2, 3], dtype="int64"),
151+
]
152+
+ ([pd.array([1, 2, 3], dtype="int64[pyarrow]")] if has_pyarrow else []),
153+
ids=["cat", "string", "interval", "timedelta", "datetime", "numpy"]
154+
+ (["pyarrow"] if has_pyarrow else []),
155+
)
156+
@pytest.mark.parametrize("is_index", [True, False])
157+
def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None:
158+
df = pd.DataFrame({"arr": extension_array})
159+
if is_index:
160+
df = df.set_index("arr")
161+
arr = xr.Dataset.from_dataframe(df)["arr"]
140162
roundtripped = arr.to_pandas()
141-
assert (df["cat"] == roundtripped).all()
142-
assert df["cat"].dtype == roundtripped.dtype
143-
xr.testing.assert_identical(arr, roundtripped.to_xarray())
163+
df_arr_to_test = df.index if is_index else df["arr"]
164+
assert (df_arr_to_test == roundtripped).all()
165+
# `NumpyExtensionArray` types are not roundtripped, including `StringArray` which subtypes.
166+
if isinstance(extension_array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined]
167+
assert isinstance(arr.data, np.ndarray)
168+
else:
169+
assert (
170+
df_arr_to_test.dtype
171+
== (roundtripped.index if is_index else roundtripped).dtype
172+
)
173+
xr.testing.assert_identical(arr, roundtripped.to_xarray())

xarray/core/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
parse_dims_as_set,
100100
)
101101
from xarray.core.variable import (
102+
UNSUPPORTED_EXTENSION_ARRAY_TYPES,
102103
IndexVariable,
103104
Variable,
104105
as_variable,
@@ -7281,7 +7282,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
72817282
extension_arrays = []
72827283
for k, v in dataframe.items():
72837284
if not is_extension_array_dtype(v) or isinstance(
7284-
v.array, pd.arrays.DatetimeArray | pd.arrays.TimedeltaArray
7285+
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
72857286
):
72867287
arrays.append((k, np.asarray(v)))
72877288
else:

xarray/core/extension_array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
9393
def __post_init__(self):
9494
if not isinstance(self.array, pd.api.extensions.ExtensionArray):
9595
raise TypeError(f"{self.array} is not an pandas ExtensionArray.")
96+
# This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
97+
# we do support extension arrays from datetime, for example, that need
98+
# duck array support internally via this class.
99+
if isinstance(self.array, pd.arrays.NumpyExtensionArray):
100+
raise TypeError(
101+
"`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
102+
)
96103

97104
def __array_function__(self, func, types, args, kwargs):
98105
def replace_duck_with_extension_array(args) -> list:

xarray/core/indexing.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,8 +1802,12 @@ def __array__(
18021802

18031803
def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
18041804
# We return an PandasExtensionArray wrapper type that satisfies
1805-
# duck array protocols. This is what's needed for tests to pass.
1806-
if pd.api.types.is_extension_array_dtype(self.array):
1805+
# duck array protocols.
1806+
# `NumpyExtensionArray` is excluded
1807+
if pd.api.types.is_extension_array_dtype(self.array) and not isinstance(
1808+
self.array.array,
1809+
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
1810+
):
18071811
from xarray.core.extension_array import PandasExtensionArray
18081812

18091813
return PandasExtensionArray(self.array.array)

xarray/core/variable.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@
6363
)
6464
# https://github.com/python/mypy/issues/224
6565
BASIC_INDEXING_TYPES = integer_types + (slice,)
66+
UNSUPPORTED_EXTENSION_ARRAY_TYPES = (
67+
pd.arrays.DatetimeArray,
68+
pd.arrays.TimedeltaArray,
69+
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
70+
)
6671

6772
if TYPE_CHECKING:
6873
from xarray.core.types import (
@@ -190,6 +195,8 @@ def _maybe_wrap_data(data):
190195
"""
191196
if isinstance(data, pd.Index):
192197
return PandasIndexingAdapter(data)
198+
if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES):
199+
return data.to_numpy()
193200
if isinstance(data, pd.api.extensions.ExtensionArray):
194201
return PandasExtensionArray(data)
195202
return data
@@ -251,7 +258,14 @@ def convert_non_numpy_type(data):
251258

252259
# we don't want nested self-described arrays
253260
if isinstance(data, pd.Series | pd.DataFrame):
254-
pandas_data = data.values
261+
if (
262+
isinstance(data, pd.Series)
263+
and pd.api.types.is_extension_array_dtype(data)
264+
and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES)
265+
):
266+
pandas_data = data.array
267+
else:
268+
pandas_data = data.values # type: ignore[assignment]
255269
if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
256270
return convert_non_numpy_type(pandas_data)
257271
else:

0 commit comments

Comments
 (0)