Skip to content

Commit d3ce8be

Browse files
Add support for list type in get (#20332)
Closes: #20326 There are instances where we result in `list` type in `StringMethods`, for which pandas returns `object` dtype. But calling `.str.xyz` on that result is technically allowed and valid for some APIs. This is pr enables `list` type for `.str.get` API. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: #20332
1 parent c8a326b commit d3ce8be

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

python/cudf/cudf/core/accessors/string.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from cudf.api.extensions import no_default
1717
from cudf.api.types import is_integer, is_scalar
1818
from cudf.core.accessors.base_accessor import BaseAccessor
19+
from cudf.core.accessors.lists import ListMethods
1920
from cudf.core.column.column import ColumnBase, as_column, column_empty
2021
from cudf.core.dtypes import ListDtype
2122
from cudf.options import get_option
@@ -2313,6 +2314,8 @@ def get(self, i: int = 0) -> Series | Index:
23132314
2 f
23142315
dtype: object
23152316
"""
2317+
if isinstance(self._column.dtype, ListDtype):
2318+
return ListMethods(self._parent).get(i)
23162319
str_lens = self.len()
23172320
if i < 0:
23182321
next_index = i - 1

python/cudf/cudf/tests/series/accessors/test_str.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2844,3 +2844,18 @@ def test_string_misc_name(ps_gs, name):
28442844
assert_eq(ps + ps, gs + gs)
28452845
assert_eq(ps + "RAPIDS", gs + "RAPIDS")
28462846
assert_eq("RAPIDS" + ps, "RAPIDS" + gs)
2847+
2848+
2849+
def test_string_list_get_access():
2850+
ps = pd.Series(["a,b,c", "d,e,f", None, "g,h,i"])
2851+
gs = cudf.from_pandas(ps)
2852+
2853+
expect = ps.str.split(",")
2854+
got = gs.str.split(",")
2855+
2856+
assert_eq(expect, got)
2857+
2858+
expect = expect.str.get(1)
2859+
got = got.str.get(1)
2860+
2861+
assert_eq(expect, got)

0 commit comments

Comments
 (0)