Skip to content

Feature/153 support polars in pin_write to parquet #263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ board.pin_write(mtcars.head(), "mtcars", type="csv")

Above, we saved the data as a CSV, but depending on what you’re saving
and who else you want to read it, you might use the `type` argument to
instead save it as a `joblib`, `parquet`, or `json` file.
instead save it as a `joblib`, `parquet`, or `json` file. If you're using
a `polars.DataFrame`, you can save to `parquet`.

You can later retrieve the pinned data with `.pin_read()`:
You can later retrieve the pinned data as a `pandas.DataFrame` with `.pin_read()`:

``` python
board.pin_read("mtcars")
Expand Down
3 changes: 2 additions & 1 deletion README.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ board.pin_write(mtcars.head(), "mtcars", type="csv")
Above, we saved the data as a CSV, but depending on
what you’re saving and who else you want to read it, you might use the
`type` argument to instead save it as a `joblib`, `parquet`, or `json` file.
If you're using a `polars.DataFrame`, you can save to `parquet`.

You can later retrieve the pinned data with `.pin_read()`:
You can later retrieve the pinned data as a `pandas.DataFrame` with `.pin_read()`:

```{python}
board.pin_read("mtcars")
Expand Down
124 changes: 100 additions & 24 deletions pins/drivers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Sequence
from typing import Literal, Sequence, TypeAlias

from .config import PINS_ENV_INSECURE_READ, get_allow_pickle_read
from .errors import PinsInsecureReadError
Expand All @@ -11,15 +11,7 @@

UNSAFE_TYPES = frozenset(["joblib"])
REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib", "file"])


def _assert_is_pandas_df(x, file_type: str) -> None:
import pandas as pd

if not isinstance(x, pd.DataFrame):
raise NotImplementedError(
f"Currently only pandas.DataFrame can be saved as type {file_type!r}."
)
_DFLib: TypeAlias = Literal["pandas", "polars"]


def load_path(meta, path_to_version):
Expand Down Expand Up @@ -152,28 +144,31 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
final_name = f"{fname}{suffix}"

if type == "csv":
_assert_is_pandas_df(obj, file_type=type)

_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
obj.to_csv(final_name, index=False)

elif type == "arrow":
# NOTE: R pins accepts the type arrow, and saves it as feather.
# we allow reading this type, but raise an error for writing.
_assert_is_pandas_df(obj, file_type=type)

_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
obj.to_feather(final_name)

elif type == "feather":
_assert_is_pandas_df(obj, file_type=type)
_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)

raise NotImplementedError(
'Saving data as type "feather" no longer supported. Use type "arrow" instead.'
)

elif type == "parquet":
_assert_is_pandas_df(obj, file_type=type)
df_lib = _choose_df_lib(obj, supported_libs=["pandas", "polars"], file_type=type)

obj.to_parquet(final_name)
if df_lib == "pandas":
obj.to_parquet(final_name)
elif df_lib == "polars":
obj.write_parquet(final_name)
else:
raise NotImplementedError

elif type == "joblib":
import joblib
Expand All @@ -200,13 +195,94 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen


def default_title(obj, name):
    """Return a default one-line title for a pinned object.

    For recognized DataFrame types the title includes the frame's dimensions;
    any other object falls back to its qualified class name.
    """
    # Display name per DataFrame library (both currently read "DataFrame").
    object_names: dict[_DFLib, str] = {
        "polars": "DataFrame",
        "pandas": "DataFrame",
    }

    try:
        df_lib = _choose_df_lib(obj)
    except NotImplementedError:
        # Not a recognized DataFrame -- fall back to the class name.
        obj_name = type(obj).__qualname__
        return f"{name}: a pinned {obj_name} object"

    # TODO(compat): title says CSV rather than data.frame
    # see https://github.com/machow/pins-python/issues/5
    shape_str = " x ".join(str(dim) for dim in obj.shape)
    return f"{name}: a pinned {shape_str} {object_names[df_lib]}"


def _choose_df_lib(
df,
*,
supported_libs: list[_DFLib] | None = None,
file_type: str | None = None,
) -> _DFLib:
"""Return the library associated with a DataFrame, e.g. "pandas".

The arguments `supported_libs` and `file_type` must be specified together, and are
meant to be used when saving an object, to choose the appropriate library.

Args:
df:
The object to check - might not be a DataFrame necessarily.
supported_libs:
The DataFrame libraries to accept for this df.
file_type:
The file type we're trying to save to - used to give more specific error
messages.

Raises:
NotImplementedError: If the DataFrame type is not recognized, or not supported.
"""
if (supported_libs is None) + (file_type is None) == 1:
raise ValueError("Must provide both or neither of supported_libs and file_type")

df_libs: list[_DFLib] = []

# pandas
import pandas as pd

if isinstance(obj, pd.DataFrame):
# TODO(compat): title says CSV rather than data.frame
# see https://github.com/machow/pins-python/issues/5
shape_str = " x ".join(map(str, obj.shape))
return f"{name}: a pinned {shape_str} DataFrame"
if isinstance(df, pd.DataFrame):
df_libs.append("pandas")

# polars
try:
import polars as pl
except ModuleNotFoundError:
pass
else:
obj_name = type(obj).__qualname__
return f"{name}: a pinned {obj_name} object"
if isinstance(df, pl.DataFrame):
df_libs.append("polars")

# Make sure there's only one library associated with the dataframe
if len(df_libs) == 1:
(df_lib,) = df_libs
elif len(df_libs) > 1:
msg = (
f"Hybrid DataFrames are not supported: "
f"should only be one of {supported_libs!r}, "
f"but got an object from multiple libraries {df_libs!r}."
)
raise NotImplementedError(msg)
else:
raise NotImplementedError(f"Unrecognized DataFrame type: {type(df)}")

# Raise if the library is not supported
if supported_libs is not None and df_lib not in supported_libs:
ftype_clause = f"for type {file_type!r}"

if len(supported_libs) == 1:
msg = (
f"Currently only {supported_libs[0]} DataFrames can be saved "
f"{ftype_clause}. DataFrames from {df_lib} are not yet supported."
)
else:
msg = (
f"Currently only DataFrames from the following libraries can be saved "
f"{ftype_clause}: {supported_libs!r}."
)

raise NotImplementedError(msg)

return df_lib
88 changes: 87 additions & 1 deletion pins/tests/test_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import fsspec
import pandas as pd
import polars as pl
import pytest

from pins.config import PINS_ENV_INSECURE_READ
from pins.drivers import default_title, load_data, save_data
from pins.drivers import _choose_df_lib, default_title, load_data, save_data
from pins.errors import PinsInsecureReadError
from pins.meta import MetaRaw
from pins.tests.helpers import rm_env
Expand Down Expand Up @@ -34,6 +35,7 @@ class D:
[
(pd.DataFrame({"x": [1, 2]}), "somename: a pinned 2 x 1 DataFrame"),
(pd.DataFrame({"x": [1], "y": [2]}), "somename: a pinned 1 x 2 DataFrame"),
(pl.DataFrame({"x": [1, 2]}), "somename: a pinned 2 x 1 DataFrame"),
(ExC(), "somename: a pinned ExC object"),
(ExC().D(), "somename: a pinned ExC.D object"),
([1, 2, 3], "somename: a pinned list object"),
Expand Down Expand Up @@ -76,6 +78,36 @@ def test_driver_roundtrip(tmp_path: Path, type_):
assert df.equals(obj)


@pytest.mark.parametrize(
    "type_",
    [
        "parquet",
    ],
)
def test_driver_polars_roundtrip(tmp_path, type_):
    """A polars frame saved by save_data should load back (as pandas) unchanged."""
    import polars as pl

    original = pl.DataFrame({"x": [1, 2, 3]})

    pin_name = "some_df"
    expected_file = f"{pin_name}.{type_}"

    saved_path = save_data(original, tmp_path / pin_name, type_)
    assert Path(saved_path).name == expected_file

    meta = MetaRaw(expected_file, type_, "my_pin")
    loaded_pandas = load_data(
        meta, fsspec.filesystem("file"), tmp_path, allow_pickle_read=True
    )

    # load_data always yields pandas; convert back to polars to compare.
    round_tripped = pl.DataFrame(loaded_pandas)

    assert original.equals(round_tripped)


@pytest.mark.parametrize(
"type_",
[
Expand Down Expand Up @@ -159,3 +191,57 @@ def test_driver_apply_suffix_false(tmp_path: Path):
res_fname = save_data(df, p_obj, type_, apply_suffix=False)

assert Path(res_fname).name == "some_df"


class TestChooseDFLib:
    """Unit tests for ``_choose_df_lib``."""

    def test_pandas(self):
        df = pd.DataFrame({"x": [1]})
        assert _choose_df_lib(df) == "pandas"

    def test_polars(self):
        df = pl.DataFrame({"x": [1]})
        assert _choose_df_lib(df) == "polars"

    def test_list_raises(self):
        # Non-DataFrame inputs are rejected outright.
        expected = "Unrecognized DataFrame type: <class 'list'>"
        with pytest.raises(NotImplementedError, match=expected):
            _choose_df_lib([])

    def test_pandas_subclass(self):
        # Subclasses of pd.DataFrame still count as pandas.
        class MyDataFrame(pd.DataFrame):
            pass

        df = MyDataFrame({"x": [1]})
        assert _choose_df_lib(df) == "pandas"

    def test_ftype_compatible(self):
        result = _choose_df_lib(
            pd.DataFrame({"x": [1]}), supported_libs=["pandas"], file_type="csv"
        )
        assert result == "pandas"

    def test_ftype_incompatible(self):
        expected = (
            "Currently only pandas DataFrames can be saved for type 'csv'. "
            "DataFrames from polars are not yet supported."
        )
        with pytest.raises(NotImplementedError, match=expected):
            _choose_df_lib(
                pl.DataFrame({"x": [1]}), supported_libs=["pandas"], file_type="csv"
            )

    def test_supported_alone_raises(self):
        expected = "Must provide both or neither of supported_libs and file_type"
        with pytest.raises(ValueError, match=expected):
            _choose_df_lib(..., supported_libs=["pandas"])

    def test_file_type_alone_raises(self):
        expected = "Must provide both or neither of supported_libs and file_type"
        with pytest.raises(ValueError, match=expected):
            _choose_df_lib(..., file_type="csv")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ test = [
"pytest-dotenv",
"pytest-parallel",
"s3fs",
"polars>=1.0.0",
]

[build-system]
Expand Down
2 changes: 2 additions & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ pluggy==1.5.0
# via pytest
plum-dispatch==2.5.1.post1
# via quartodoc
polars==1.2.1
# via pins (setup.cfg)
portalocker==2.10.1
# via msal-extensions
pre-commit==3.7.1
Expand Down
Loading