Skip to content

Commit 71e58c9

Browse files
Refactoring and adding tests.
1 parent 8294e47 commit 71e58c9

File tree

2 files changed

+78
-13
lines changed

2 files changed

+78
-13
lines changed

pins/drivers.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -197,36 +197,48 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
197197

198198
def default_title(obj, name):
199199
try:
200-
_choose_df_lib(obj)
200+
df_lib = _choose_df_lib(obj)
201201
except NotImplementedError:
202202
obj_name = type(obj).__qualname__
203203
return f"{name}: a pinned {obj_name} object"
204204

205+
_df_lib_to_objname: dict[_DFLib, str] = {
206+
"polars": "DataFrame",
207+
"pandas": "DataFrame",
208+
}
209+
205210
# TODO(compat): title says CSV rather than data.frame
206211
# see https://github.com/machow/pins-python/issues/5
207212
shape_str = " x ".join(map(str, obj.shape))
208-
return f"{name}: a pinned {shape_str} DataFrame"
213+
return f"{name}: a pinned {shape_str} {_df_lib_to_objname[df_lib]}"
209214

210215

211216
def _choose_df_lib(
212217
df,
213218
*,
214-
supported_libs: list[_DFLib] = ["pandas", "polars"],
219+
supported_libs: list[_DFLib] | None = None,
215220
file_type: str | None = None,
216221
) -> _DFLib:
217-
"""Return the type of DataFrame library used in the given DataFrame.
222+
"""Return the library associated with a DataFrame, e.g. "pandas".
223+
224+
The arguments `supported_libs` and `file_type` must be specified together, and are
225+
meant to be used when saving an object, to choose the appropriate library.
218226
219227
Args:
220228
df:
221229
The object to check - might not be a DataFrame necessarily.
222230
supported_libs:
223231
The DataFrame libraries to accept for this df.
224232
file_type:
225-
The file type we're trying to save to - used to give more specific error messages.
233+
The file type we're trying to save to - used to give more specific error
234+
messages.
226235
227236
Raises:
228-
NotImplementedError: If the DataFrame type is not recognized.
237+
NotImplementedError: If the DataFrame type is not recognized, or not supported.
229238
"""
239+
if (supported_libs is None) + (file_type is None) == 1:
240+
raise ValueError("Must provide both or neither of supported_libs and file_type")
241+
230242
df_libs: list[_DFLib] = []
231243

232244
# pandas
@@ -244,6 +256,7 @@ def _choose_df_lib(
244256
if isinstance(df, pl.DataFrame):
245257
df_libs.append("polars")
246258

259+
# Make sure there's only one library associated with the dataframe
247260
if len(df_libs) == 1:
248261
(df_lib,) = df_libs
249262
elif len(df_libs) > 1:
@@ -256,16 +269,14 @@ def _choose_df_lib(
256269
else:
257270
raise NotImplementedError(f"Unrecognized DataFrame type: {type(df)}")
258271

259-
if df_lib not in supported_libs:
260-
if file_type is None:
261-
ftype_clause = "in pins"
262-
else:
263-
ftype_clause = f"for type {file_type!r}"
272+
# Raise if the library is not supported
273+
if supported_libs is not None and df_lib not in supported_libs:
274+
ftype_clause = f"for type {file_type!r}"
264275

265276
if len(supported_libs) == 1:
266277
msg = (
267278
f"Currently only {supported_libs[0]} DataFrames can be saved "
268-
f"{ftype_clause}. {df_lib} DataFrames are not yet supported."
279+
f"{ftype_clause}. DataFrames from {df_lib} are not yet supported."
269280
)
270281
else:
271282
msg = (

pins/tests/test_drivers.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from pins.meta import MetaRaw
1111
from pins.config import PINS_ENV_INSECURE_READ
12-
from pins.drivers import load_data, save_data, default_title
12+
from pins.drivers import load_data, save_data, default_title, _choose_df_lib
1313
from pins.errors import PinsInsecureReadError
1414

1515

@@ -192,3 +192,57 @@ def test_driver_apply_suffix_false(tmp_path: Path):
192192
res_fname = save_data(df, p_obj, type_, apply_suffix=False)
193193

194194
assert Path(res_fname).name == "some_df"
195+
196+
197+
class TestChooseDFLib:
198+
def test_pandas(self):
199+
assert _choose_df_lib(pd.DataFrame({"x": [1]})) == "pandas"
200+
201+
def test_polars(self):
202+
assert _choose_df_lib(pl.DataFrame({"x": [1]})) == "polars"
203+
204+
def test_list_raises(self):
205+
with pytest.raises(
206+
NotImplementedError, match="Unrecognized DataFrame type: <class 'list'>"
207+
):
208+
_choose_df_lib([])
209+
210+
def test_pandas_subclass(self):
211+
class MyDataFrame(pd.DataFrame):
212+
pass
213+
214+
assert _choose_df_lib(MyDataFrame({"x": [1]})) == "pandas"
215+
216+
def test_ftype_compatible(self):
217+
assert (
218+
_choose_df_lib(
219+
pd.DataFrame({"x": [1]}), supported_libs=["pandas"], file_type="csv"
220+
)
221+
== "pandas"
222+
)
223+
224+
def test_ftype_incompatible(self):
225+
with pytest.raises(
226+
NotImplementedError,
227+
match=(
228+
"Currently only pandas DataFrames can be saved for type 'csv'. "
229+
"DataFrames from polars are not yet supported."
230+
),
231+
):
232+
_choose_df_lib(
233+
pl.DataFrame({"x": [1]}), supported_libs=["pandas"], file_type="csv"
234+
)
235+
236+
def test_supported_alone_raises(self):
237+
with pytest.raises(
238+
ValueError,
239+
match="Must provide both or neither of supported_libs and file_type",
240+
):
241+
_choose_df_lib(..., supported_libs=["pandas"])
242+
243+
def test_file_type_alone_raises(self):
244+
with pytest.raises(
245+
ValueError,
246+
match="Must provide both or neither of supported_libs and file_type",
247+
):
248+
_choose_df_lib(..., file_type="csv")

0 commit comments

Comments
 (0)