diff --git a/CHANGELOG.md b/CHANGELOG.md index faf3d4fd8..7d9febfb7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,9 @@ ## [Unreleased] - [ENH] Added `row_count` parameter for janitor.conditional_join - Issue #1269 @samukweku -- [ENG] Reverse deprecation of `pivot_wider()` -- Issue #1464 +- [ENH] Reverse deprecation of `pivot_wider()` -- Issue #1464 +- [ENH] Add accessor and method for pandas DataFrameGroupBy objects. - Issue #587 @samukweku + ## [v0.31.0] - 2025-03-07 - [ENH] Added support for pd.Series.select - Issue #1394 @samukweku diff --git a/janitor/functions/select.py b/janitor/functions/select.py index 0c076fcfb..9e7f4136d 100644 --- a/janitor/functions/select.py +++ b/janitor/functions/select.py @@ -327,17 +327,18 @@ def select_rows( return _select(df, rows=list(args), invert=invert) +@pf.register_groupby_method @pf.register_dataframe_method @pf.register_series_method @deprecated_alias(rows="index") def select( - df: pd.DataFrame | pd.Series, + df: pd.DataFrame | pd.Series | DataFrameGroupBy, *args: tuple, index: Any = None, columns: Any = None, axis: str = "columns", invert: bool = False, -) -> pd.DataFrame | pd.Series: +) -> pd.DataFrame | pd.Series | DataFrameGroupBy: """Method-chainable selection of rows and/or columns. It accepts a string, shell-like glob strings `(*string*)`, @@ -371,6 +372,8 @@ def select( - `rows` keyword deprecated in favour of `index`. - 0.31.0 - Add support for pd.Series. + - 0.32.0 + - Add support for DataFrameGroupBy. Examples: >>> import pandas as pd @@ -436,6 +439,10 @@ def select( Returns: A pandas DataFrame or Series with the specified rows and/or columns selected. 
""" # noqa: E501 + if args and isinstance(df, DataFrameGroupBy): + return get_columns(group=df, label=list(args)) + if isinstance(df, DataFrameGroupBy): + return get_columns(group=df, label=[columns]) if args: check("invert", invert, [bool]) if (index is not None) or (columns is not None): @@ -478,6 +485,12 @@ def get_index_labels( return index[_select_index(arg, df, axis)] +@refactored_function( + message=( + "This function will be deprecated in a 1.x release. " + "Please use `jn.select` instead." + ) +) def get_columns( group: DataFrameGroupBy | SeriesGroupBy, label: Any ) -> DataFrameGroupBy | SeriesGroupBy: @@ -488,6 +501,11 @@ def get_columns( !!! info "New in version 0.25.0" + !!! note + + This function will be deprecated in a 1.x release. + Please use `jn.select` instead. + Args: group: A Pandas GroupBy object. label: column(s) to select. diff --git a/tests/functions/test_select_columns.py b/tests/functions/test_select_columns.py index 1444a7117..bea545467 100644 --- a/tests/functions/test_select_columns.py +++ b/tests/functions/test_select_columns.py @@ -479,6 +479,20 @@ def test_select_groupby(dataframe): assert_frame_equal(expected, actual) +def test_select_groupby_args(dataframe): + """Test select on a grouped object - columns passed as positional args""" + expected = dataframe.select_dtypes("number").groupby(dataframe["a"]).sum() + actual = dataframe.groupby("a").select(is_numeric_dtype).sum() + assert_frame_equal(expected, actual) + + +def test_select_groupby_columns(dataframe): + """Test select on a grouped object - columns passed via `columns` keyword""" + expected = dataframe.select_dtypes("number").groupby(dataframe["a"]).sum() + actual = dataframe.groupby("a").select(columns=is_numeric_dtype).sum() + assert_frame_equal(expected, actual) + + def test_select_str_multiindex(multiindex): """Test str selection on a MultiIndex - exact match""" expected = multiindex.select_columns("bar")