Skip to content

Commit d682808

Browse files
Write more robust df-library choosing logic.
1 parent 62bffaf commit d682808

File tree

1 file changed

+82
-49
lines changed

1 file changed

+82
-49
lines changed

pins/drivers.py

+82-49
Original file line numberDiff line numberDiff line change
@@ -4,47 +4,15 @@
44
from .meta import Meta
55
from .errors import PinsInsecureReadError
66

7-
from typing import Literal, Sequence
7+
from typing import Literal, Sequence, TypeAlias
88

99
# TODO: move IFileSystem out of boards, to fix circular import
1010
# from .boards import IFileSystem
1111

1212

1313
UNSAFE_TYPES = frozenset(["joblib"])
1414
REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib", "file"])
15-
16-
17-
def _assert_is_pandas_df(x, file_type: str) -> None:
18-
df_family = _get_df_family(x)
19-
20-
if df_family != "pandas":
21-
raise NotImplementedError(
22-
f"Currently only pandas.DataFrame can be saved as type {file_type!r}."
23-
)
24-
25-
26-
def _get_df_family(df) -> Literal["pandas", "polars"]:
27-
"""Return the type of DataFrame, or raise NotImplementedError if we can't decide."""
28-
try:
29-
import polars as pl
30-
except ModuleNotFoundError:
31-
is_polars_df = False
32-
else:
33-
is_polars_df = isinstance(df, pl.DataFrame)
34-
35-
import pandas as pd
36-
37-
is_pandas_df = isinstance(df, pd.DataFrame)
38-
39-
if is_polars_df and is_pandas_df:
40-
raise NotImplementedError(
41-
"Hybrid DataFrames (simultaneously pandas and polars) are not supported."
42-
)
43-
elif is_polars_df:
44-
return "polars"
45-
elif is_pandas_df:
46-
return "pandas"
47-
raise NotImplementedError(f"Unrecognized DataFrame type: {type(df)}")
15+
_DFLib: TypeAlias = Literal["pandas", "polars"]
4816

4917

5018
def load_path(meta, path_to_version):
@@ -177,36 +145,31 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
177145
final_name = f"{fname}{suffix}"
178146

179147
if type == "csv":
180-
_assert_is_pandas_df(obj, file_type=type)
181-
148+
_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
182149
obj.to_csv(final_name, index=False)
183150

184151
elif type == "arrow":
185152
# NOTE: R pins accepts the type arrow, and saves it as feather.
186153
# we allow reading this type, but raise an error for writing.
187-
_assert_is_pandas_df(obj, file_type=type)
188-
154+
_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
189155
obj.to_feather(final_name)
190156

191157
elif type == "feather":
192-
_assert_is_pandas_df(obj, file_type=type)
158+
_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
193159

194160
raise NotImplementedError(
195161
'Saving data as type "feather" no longer supported. Use type "arrow" instead.'
196162
)
197163

198164
elif type == "parquet":
199-
df_family = _get_df_family(obj)
200-
if df_family == "polars":
201-
obj.write_parquet(final_name)
202-
elif df_family == "pandas":
165+
df_lib = _choose_df_lib(obj, supported_libs=["pandas", "polars"], file_type=type)
166+
167+
if df_lib == "pandas":
203168
obj.to_parquet(final_name)
169+
elif df_lib == "polars":
170+
obj.write_parquet(final_name)
204171
else:
205-
msg = (
206-
"Currently only pandas.DataFrame and polars.DataFrame can be saved to "
207-
"a parquet file."
208-
)
209-
raise NotImplementedError(msg)
172+
raise NotImplementedError
210173

211174
elif type == "joblib":
212175
import joblib
@@ -234,7 +197,7 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
234197

235198
def default_title(obj, name):
236199
try:
237-
_get_df_family(obj)
200+
_choose_df_lib(obj)
238201
except NotImplementedError:
239202
obj_name = type(obj).__qualname__
240203
return f"{name}: a pinned {obj_name} object"
@@ -243,3 +206,73 @@ def default_title(obj, name):
243206
# see https://github.com/machow/pins-python/issues/5
244207
shape_str = " x ".join(map(str, obj.shape))
245208
return f"{name}: a pinned {shape_str} DataFrame"
209+
210+
211+
def _choose_df_lib(
    df,
    *,
    supported_libs: "Sequence[_DFLib] | None" = None,
    file_type: "str | None" = None,
) -> "_DFLib":
    """Return the name of the DataFrame library that the given object belongs to.

    Args:
        df:
            The object to check - might not be a DataFrame necessarily.
        supported_libs:
            The DataFrame libraries to accept for this df. When omitted (None),
            both "pandas" and "polars" are accepted.
        file_type:
            The file type we're trying to save to - used to give more specific error messages.

    Returns:
        "pandas" or "polars", whichever library `df` is a DataFrame of.

    Raises:
        NotImplementedError: If `df` is not a recognized DataFrame, if it is an
            instance of more than one DataFrame library at once, or if its library
            is not listed in `supported_libs`.
    """
    # Resolve the default here rather than in the signature, so that no mutable
    # list object is shared between calls.
    if supported_libs is None:
        supported_libs = ["pandas", "polars"]

    df_libs: list[_DFLib] = []

    # pandas
    import pandas as pd

    if isinstance(df, pd.DataFrame):
        df_libs.append("pandas")

    # polars is imported lazily and its absence is tolerated: if it cannot be
    # imported, `df` is simply never classified as a polars DataFrame.
    try:
        import polars as pl
    except ModuleNotFoundError:
        pass
    else:
        if isinstance(df, pl.DataFrame):
            df_libs.append("polars")

    if len(df_libs) == 1:
        (df_lib,) = df_libs
    elif len(df_libs) > 1:
        msg = (
            f"Hybrid DataFrames are not supported: "
            f"should only be one of {supported_libs!r}, "
            f"but got an object from multiple libraries {df_libs!r}."
        )
        raise NotImplementedError(msg)
    else:
        raise NotImplementedError(f"Unrecognized DataFrame type: {type(df)}")

    if df_lib not in supported_libs:
        if file_type is None:
            ftype_clause = "in pins"
        else:
            ftype_clause = f"for type {file_type!r}"

        if len(supported_libs) == 1:
            msg = (
                f"Currently only {supported_libs[0]} DataFrames can be saved "
                f"{ftype_clause}. {df_lib} DataFrames are not yet supported."
            )
        else:
            msg = (
                f"Currently only DataFrames from the following libraries can be saved "
                f"{ftype_clause}: {supported_libs!r}."
            )

        raise NotImplementedError(msg)

    return df_lib

0 commit comments

Comments
 (0)