From 275df316d68ccfc45fb1c81e187716951875d6ad Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 15 Oct 2025 22:30:01 +0000 Subject: [PATCH] feat: Add df.sort_index(axis=1) --- bigframes/dataframe.py | 32 ++++++++++++------- tests/system/small/test_dataframe.py | 12 +++++-- tests/unit/test_dataframe_polars.py | 12 +++++-- .../bigframes_vendored/pandas/core/frame.py | 4 +++ 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index bc2bbb963b..f8aa7ec9be 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2555,25 +2555,33 @@ def sort_index( ) -> None: ... - @validations.requires_index def sort_index( self, *, + axis: Union[int, str] = 0, ascending: bool = True, inplace: bool = False, na_position: Literal["first", "last"] = "last", ) -> Optional[DataFrame]: - if na_position not in ["first", "last"]: - raise ValueError("Param na_position must be one of 'first' or 'last'") - na_last = na_position == "last" - index_columns = self._block.index_columns - ordering = [ - order.ascending_over(column, na_last) - if ascending - else order.descending_over(column, na_last) - for column in index_columns - ] - block = self._block.order_by(ordering) + if utils.get_axis_number(axis) == 0: + if na_position not in ["first", "last"]: + raise ValueError("Param na_position must be one of 'first' or 'last'") + na_last = na_position == "last" + index_columns = self._block.index_columns + ordering = [ + order.ascending_over(column, na_last) + if ascending + else order.descending_over(column, na_last) + for column in index_columns + ] + block = self._block.order_by(ordering) + else: # axis=1 + _, indexer = self.columns.sort_values( + return_indexer=True, ascending=ascending, na_position=na_position # type: ignore + ) + block = self._block.select_columns( + [self._block.value_columns[i] for i in indexer] + ) if inplace: self._set_block(block) return None diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 1e6151b7f4..34bb5a4fb3 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -2406,13 +2406,19 @@ def test_set_index_key_error(scalars_dfs): ("na_position",), (("first",), ("last",)), ) -def test_sort_index(scalars_dfs, ascending, na_position): +@pytest.mark.parametrize( + ("axis",), + ((0,), ("columns",)), +) +def test_sort_index(scalars_dfs, ascending, na_position, axis): index_column = "int64_col" scalars_df, scalars_pandas_df = scalars_dfs df = scalars_df.set_index(index_column) - bf_result = df.sort_index(ascending=ascending, na_position=na_position).to_pandas() + bf_result = df.sort_index( + ascending=ascending, na_position=na_position, axis=axis + ).to_pandas() pd_result = scalars_pandas_df.set_index(index_column).sort_index( - ascending=ascending, na_position=na_position + ascending=ascending, na_position=na_position, axis=axis ) pandas.testing.assert_frame_equal(bf_result, pd_result) diff --git a/tests/unit/test_dataframe_polars.py b/tests/unit/test_dataframe_polars.py index a6f5c3d1ef..b83380d789 100644 --- a/tests/unit/test_dataframe_polars.py +++ b/tests/unit/test_dataframe_polars.py @@ -1757,13 +1757,19 @@ def test_set_index_key_error(scalars_dfs): ("na_position",), (("first",), ("last",)), ) -def test_sort_index(scalars_dfs, ascending, na_position): +@pytest.mark.parametrize( + ("axis",), + ((0,), ("columns",)), +) +def test_sort_index(scalars_dfs, ascending, na_position, axis): index_column = "int64_col" scalars_df, scalars_pandas_df = scalars_dfs df = scalars_df.set_index(index_column) - bf_result = df.sort_index(ascending=ascending, na_position=na_position).to_pandas() + bf_result = df.sort_index( + ascending=ascending, na_position=na_position, axis=axis + ).to_pandas() pd_result = scalars_pandas_df.set_index(index_column).sort_index( - ascending=ascending, na_position=na_position + ascending=ascending, na_position=na_position, axis=axis ) pandas.testing.assert_frame_equal(bf_result, pd_result) diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 557c332797..99733c7a3e 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -2382,6 +2382,7 @@ def sort_values( def sort_index( self, *, + axis: str | int = 0, ascending: bool = True, inplace: bool = False, na_position: Literal["first", "last"] = "last", @@ -2389,6 +2390,9 @@ def sort_index( """Sort object by labels (along an axis). Args: + axis ({0 or 'index', 1 or 'columns'}, default 0): + The axis along which to sort. The value 0 identifies the rows, + and 1 identifies the columns. ascending (bool, default True) Sort ascending vs. descending. inplace (bool, default False):