diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index c404fc74..7ee8e3f9 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -1013,7 +1013,12 @@ TimeZones: TypeAlias = str | tzinfo | None | int # Evaluates to a DataFrame column in DataFrame.assign context. IntoColumn: TypeAlias = ( - AnyArrayLike | Scalar | Callable[[DataFrame], AnyArrayLike | Scalar] | None + AnyArrayLike + | Scalar + | Callable[[DataFrame], AnyArrayLike | Scalar | list[Scalar] | range] + | list[Scalar] + | range + | None ) DatetimeLike: TypeAlias = datetime.datetime | np.datetime64 | Timestamp diff --git a/pandas-stubs/core/indexes/multi.pyi b/pandas-stubs/core/indexes/multi.pyi index 50652020..f0b00ec3 100644 --- a/pandas-stubs/core/indexes/multi.pyi +++ b/pandas-stubs/core/indexes/multi.pyi @@ -55,7 +55,7 @@ class MultiIndex(Index[Any]): @classmethod def from_product( cls, - iterables: Sequence[SequenceNotStr[Hashable]], + iterables: Sequence[SequenceNotStr[Hashable] | pd.Series | pd.Index], sortorder: int | None = ..., names: SequenceNotStr[Hashable] = ..., ) -> Self: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index eb0da931..094a4546 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -45,7 +45,9 @@ ) import xarray as xr -from pandas._typing import Scalar +from pandas._typing import ( + Scalar, +) from tests import ( PD_LTE_23, @@ -305,9 +307,26 @@ def test_types_head_tail() -> None: def test_types_assign() -> None: df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) - df.assign(col3=lambda frame: frame.sum(axis=1)) + + check( + assert_type(df.assign(col3=lambda frame: frame.sum(axis=1)), pd.DataFrame), + pd.DataFrame, + ) df["col3"] = df.sum(axis=1) + df = pd.DataFrame({"a": [1, 2, 3]}) + check( + assert_type( + df.assign(b=lambda df: range(len(df)), c=lambda _: [10, 20, 30]), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type(df.assign(b=range(len(df)), c=[10, 20, 30]), pd.DataFrame), + pd.DataFrame, + ) + def test_assign() -> None: df = pd.DataFrame({"a": [1, 2, 3], 1: [4, 5, 6]}) diff --git a/tests/test_indexes.py b/tests/test_indexes.py index 14bbc762..2b448075 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -61,6 +61,20 @@ def test_index_astype() -> None: pd.DataFrame, ) + df = pd.DataFrame({"a": [1, 2, 3]}) + check( + assert_type( + pd.MultiIndex.from_product([["x", "y"], df.columns]), pd.MultiIndex + ), + pd.MultiIndex, + ) + check( + assert_type( + pd.MultiIndex.from_product([["x", "y"], pd.Series([1, 2])]), pd.MultiIndex + ), + pd.MultiIndex, + ) + def test_multiindex_get_level_values() -> None: mi = pd.MultiIndex.from_product([["a", "b"], ["c", "d"]], names=["ab", "cd"])