Skip to content

Commit d74afaf

Browse files
authored
Add ndarray subclassing support via ndarray.view() method (#2815)
This PR adds ndarray subclassing support via the `ndarray.view()` method. It implements the `type` parameter of `dpnp.ndarray.view()` to enable custom subclasses, matching NumPy/CuPy behavior, and invokes the `__array_finalize__` hook so that metadata is propagated to the view. The implementation goes through a `_view_impl()` helper that takes an `array_class` parameter to avoid shadowing the builtin `type`. The test scope is extended with seven new tests verifying subclassing support, and the `TestSubclassArrayView` class from the third-party tests is enabled. This PR closes #2764.
1 parent c8c0f88 commit d74afaf

File tree

4 files changed

+232
-78
lines changed

4 files changed

+232
-78
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
2727
* Added implementation of `dpnp.divmod` [#2674](https://github.com/IntelPython/dpnp/pull/2674)
2828
* Added implementation of `dpnp.isin` function [#2595](https://github.com/IntelPython/dpnp/pull/2595)
2929
* Added implementation of `dpnp.scipy.linalg.lu` (SciPy-compatible) [#2787](https://github.com/IntelPython/dpnp/pull/2787)
30+
* Added support for ndarray subclassing via the `dpnp.ndarray.view` method with the `type` parameter [#2815](https://github.com/IntelPython/dpnp/pull/2815)
3031

3132
### Changed
3233

dpnp/dpnp_array.py

Lines changed: 150 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,136 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray):
644644
res._array_obj._set_namespace(dpnp)
645645
return res
646646

647+
def _create_view(self, array_class, shape, dtype, strides):
    """
    Create a view of an array with the specified class.

    The method builds a usm_ndarray view on top of the usm_ndarray held by
    the source array and wraps that view into an instance of `array_class`.

    Parameters
    ----------
    array_class : type
        The class to instantiate (dpnp_array or a subclass).
    shape : tuple
        Shape of the view.
    dtype : {None, dtype}
        Data type of the view. ``None`` keeps the data type of the source.
    strides : tuple
        Strides of the view, in bytes. They are converted to element
        strides before being handed over to the usm_ndarray constructor.

    Returns
    -------
    view : array_class instance
        A view of the array as the specified class.

    Raises
    ------
    ValueError
        If the view is not empty and the byte stride of an axis with more
        than one element is not a multiple of the item size of `dtype`.
        Such a stride has no exact representation in elements, so the
        view would address the wrong memory.

    """

    if dtype is None:
        dtype = self.dtype

    # the item size does not depend on the axis, so resolve it only once
    itemsize = dpnp.dtype(dtype).itemsize

    # The stride of an axis is only used when stepping along that axis, so
    # it is irrelevant for axes of length 0 or 1 and for empty views.
    if 0 not in shape and any(
        extent > 1 and byte_stride % itemsize != 0
        for extent, byte_stride in zip(shape, strides)
    ):
        raise ValueError(
            "Strides of the view must be a multiple of the itemsize "
            "of the requested dtype"
        )

    # usm_ndarray is given strides counted in elements rather than bytes
    elem_strides = tuple(
        byte_stride // itemsize for byte_stride in strides
    )

    # create the underlying usm_ndarray view
    # NOTE(review): no `offset` is forwarded here; confirm that a view of
    # a sliced source still starts at the expected element
    usm_view = dpt.usm_ndarray(
        shape,
        dtype=dtype,
        buffer=self._array_obj,
        strides=elem_strides,
    )

    # wrap the view into the appropriate class
    if array_class is dpnp_array:
        res = dpnp_array._create_from_usm_ndarray(usm_view)
    else:
        # for subclasses, allocate through __new__ (so __init__ is not
        # run) and attach the usm_ndarray view manually
        res = array_class.__new__(array_class)
        res._array_obj = usm_view
        res._array_obj._set_namespace(dpnp)

    # give the result a chance to pick up metadata from the source array
    if hasattr(res, "__array_finalize__"):
        res.__array_finalize__(self)

    return res
696+
697+
def _view_impl(self, dtype=None, array_class=None):
    """
    Internal implementation of the `view` method.

    The class of the result is received as `array_class`, so the body does
    not need a parameter named `type`, which would shadow the builtin.

    Parameters
    ----------
    dtype : {None, str, dtype object, type}, optional
        Data type of the view, or a subclass of dpnp_array, which is then
        used as `array_class`.
    array_class : {None, type}, optional
        Class of the view. ``None`` keeps the class of the source array.

    Returns
    -------
    view : array_class instance
        A view of the array.

    Raises
    ------
    ValueError
        If the class is passed through both parameters, if `array_class`
        is not a subclass of dpnp_array, or if the item size changes and
        the array is 0d, its last axis is not contiguous, or the size of
        the last axis in bytes is not a multiple of the new item size.

    """

    # a dpnp_array subclass passed as `dtype` selects the class of the view
    if isinstance(dtype, type) and issubclass(dtype, dpnp_array):
        if array_class is not None:
            raise ValueError("Cannot specify output type twice")
        array_class, dtype = dtype, None

    if array_class is None:
        # no class requested: the view keeps the class of the source
        array_class = self.__class__
    elif not (
        isinstance(array_class, type) and issubclass(array_class, dpnp_array)
    ):
        raise ValueError("Type must be a sub-type of ndarray type")

    shape = self.shape
    strides = self.strides

    if dtype is None:
        return self._create_view(array_class, shape, None, strides)

    target_dt = dtu._to_device_supported_dtype(
        dpnp.dtype(dtype), self.sycl_device
    )

    target_sz = target_dt.itemsize
    source_sz = self.dtype.itemsize
    if target_sz == source_sz:
        # same item size: shape and strides carry over unchanged
        return self._create_view(array_class, shape, target_dt, strides)

    ndim = self.ndim
    if ndim == 0:
        raise ValueError(
            "Changing the dtype of a 0d array is only supported "
            "if the itemsize is unchanged"
        )

    # only the last axis is resized when the item size changes
    last = ndim - 1
    if not (
        shape[last] == 1 or self.size == 0 or strides[last] == source_sz
    ):
        raise ValueError(
            "To change to a dtype of a different size, "
            "the last axis must be contiguous"
        )

    last_nbytes = shape[last] * source_sz
    if last_nbytes % target_sz != 0:
        raise ValueError(
            "When changing to a larger dtype, its size must be a divisor "
            "of the total size in bytes of the last axis of the array"
        )

    # leading axes are kept, the last one is re-expressed in new items
    view_strides = tuple(strides[:last]) + (target_sz,)
    view_shape = tuple(shape[:last]) + (last_nbytes // target_sz,)

    return self._create_view(array_class, view_shape, target_dt, view_strides)
776+
647777
def all(self, axis=None, *, out=None, keepdims=False, where=True):
648778
"""
649779
Return ``True`` if all elements evaluate to ``True``.
@@ -2322,10 +2452,18 @@ def view(self, /, dtype=None, *, type=None):
23222452
23232453
Parameters
23242454
----------
2325-
dtype : {None, str, dtype object}, optional
2455+
dtype : {None, str, dtype object, type}, optional
23262456
The desired data type of the returned view, e.g. :obj:`dpnp.float32`
2327-
or :obj:`dpnp.int16`. By default, it results in the view having the
2328-
same data type.
2457+
or :obj:`dpnp.int16`. Omitting it results in the view having the
2458+
same data type. Can also be a subclass of :class:`dpnp.ndarray` to
2459+
create a view of that type (this is equivalent to setting the `type`
2460+
parameter).
2461+
2462+
Default: ``None``.
2463+
type : {None, type}, optional
2464+
Type of the returned view, e.g. a subclass of :class:`dpnp.ndarray`.
2465+
If specified, the returned array will be an instance of `type`.
2466+
Omitting it results in type preservation.
23292467
23302468
Default: ``None``.
23312469
@@ -2340,11 +2478,6 @@ def view(self, /, dtype=None, *, type=None):
23402478
23412479
Only the last axis has to be contiguous.
23422480
2343-
Limitations
2344-
-----------
2345-
Parameter `type` is supported only with default value ``None``.
2346-
Otherwise, the function raises ``NotImplementedError`` exception.
2347-
23482481
Examples
23492482
--------
23502483
>>> import dpnp as np
@@ -2368,73 +2501,17 @@ def view(self, /, dtype=None, *, type=None):
23682501
[[2312, 2826],
23692502
[5396, 5910]]], dtype=int16)
23702503
2371-
"""
2372-
2373-
if type is not None:
2374-
raise NotImplementedError(
2375-
"Keyword argument `type` is supported only with "
2376-
f"default value ``None``, but got {type}."
2377-
)
2378-
2379-
old_sh = self.shape
2380-
old_strides = self.strides
2381-
2382-
if dtype is None:
2383-
return dpnp_array(old_sh, buffer=self, strides=old_strides)
2384-
2385-
new_dt = dpnp.dtype(dtype)
2386-
new_dt = dtu._to_device_supported_dtype(new_dt, self.sycl_device)
2387-
2388-
new_itemsz = new_dt.itemsize
2389-
old_itemsz = self.dtype.itemsize
2390-
if new_itemsz == old_itemsz:
2391-
return dpnp_array(
2392-
old_sh, dtype=new_dt, buffer=self, strides=old_strides
2393-
)
2394-
2395-
ndim = self.ndim
2396-
if ndim == 0:
2397-
raise ValueError(
2398-
"Changing the dtype of a 0d array is only supported "
2399-
"if the itemsize is unchanged"
2400-
)
2401-
2402-
# resize on last axis only
2403-
axis = ndim - 1
2404-
if (
2405-
old_sh[axis] != 1
2406-
and self.size != 0
2407-
and old_strides[axis] != old_itemsz
2408-
):
2409-
raise ValueError(
2410-
"To change to a dtype of a different size, "
2411-
"the last axis must be contiguous"
2412-
)
2504+
Creating a view with a custom ndarray subclass:
24132505
2414-
# normalize strides whenever itemsize changes
2415-
new_strides = tuple(
2416-
old_strides[i] if i != axis else new_itemsz for i in range(ndim)
2417-
)
2418-
2419-
new_dim = old_sh[axis] * old_itemsz
2420-
if new_dim % new_itemsz != 0:
2421-
raise ValueError(
2422-
"When changing to a larger dtype, its size must be a divisor "
2423-
"of the total size in bytes of the last axis of the array"
2424-
)
2425-
2426-
# normalize shape whenever itemsize changes
2427-
new_sh = tuple(
2428-
old_sh[i] if i != axis else new_dim // new_itemsz
2429-
for i in range(ndim)
2430-
)
2506+
>>> class MyArray(np.ndarray):
2507+
... pass
2508+
>>> x = np.array([1, 2, 3])
2509+
>>> y = x.view(MyArray)
2510+
>>> type(y)
2511+
<class 'MyArray'>
24312512
2432-
return dpnp_array(
2433-
new_sh,
2434-
dtype=new_dt,
2435-
buffer=self,
2436-
strides=new_strides,
2437-
)
2513+
"""
2514+
return self._view_impl(dtype=dtype, array_class=type)
24382515

24392516
@property
24402517
def usm_type(self):

dpnp/tests/test_ndarray.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,87 @@ def test_python_types(self, dt):
228228
expected = a.view(dt)
229229
assert_allclose(result, expected)
230230

231-
def test_type_error(self):
232-
x = dpnp.ones(4, dtype="i4")
233-
with pytest.raises(NotImplementedError):
234-
x.view("i2", type=dpnp.ndarray)
231+
def test_subclass_basic(self):
    """A view requested with `type` is an instance of exactly that class."""

    class MyArray(dpnp.ndarray):
        pass

    source = dpnp.array([1, 2, 3])
    result = source.view(type=MyArray)

    assert isinstance(result, MyArray)
    assert type(result) is MyArray
    # the view exposes the same values as the source array
    assert (result == source).all()
241+
242+
def test_dtype_type_subclass(self):
    """The subclass can be given as `type`, positionally or as `dtype`."""

    class MyArray(dpnp.ndarray):
        pass

    source = dpnp.array([1, 2, 3])

    # all three spellings have to produce the same class
    views = (
        source.view(type=MyArray),
        source.view(MyArray),
        source.view(dtype=MyArray),
    )

    for result in views:
        assert type(result) is MyArray
256+
257+
def test_subclass_array_finalize(self):
    """`__array_finalize__` copies metadata from the source to the view."""

    class ArrayWithInfo(dpnp.ndarray):
        def __array_finalize__(self, obj):
            self.info = getattr(obj, "info", "default")

    source = dpnp.array([1, 2, 3]).view(type=ArrayWithInfo)
    source.info = "metadata"

    # taking a view has to invoke __array_finalize__ with the source array
    result = source.view()

    assert hasattr(result, "info")
    assert result.info == "metadata"
    assert type(result) is ArrayWithInfo
270+
271+
def test_subclass_self_class_preservation(self):
    """A view taken without `type` keeps the class of the source array."""

    class MyArray(dpnp.ndarray):
        pass

    source = dpnp.array([1, 2, 3]).view(type=MyArray)

    # no class is requested, so MyArray has to be preserved
    result = source.view()
    assert type(result) is MyArray
280+
281+
def test_subclass_with_dtype_change(self):
    """`dtype` and `type` can be combined in a single call."""

    class MyArray(dpnp.ndarray):
        pass

    source = dpnp.array([1.0, 2.0], dtype=dpnp.float32)
    result = source.view(dtype=dpnp.int32, type=MyArray)

    assert type(result) is MyArray
    assert result.dtype == dpnp.int32
290+
291+
@pytest.mark.parametrize("xp", [dpnp, numpy])
def test_subclass_invalid_type(self, xp):
    """A `type` that is not an ndarray subclass is rejected."""

    source = xp.array([1, 2, 3])
    expected_msg = "Type must be a sub-type of ndarray type"

    with pytest.raises(ValueError, match=expected_msg):
        source.view(type=list)
298+
299+
@pytest.mark.parametrize("xp", [dpnp, numpy])
def test_subclass_double_type_specification(self, xp):
    """Passing a class through both `dtype` and `type` is an error."""

    class MyArray(xp.ndarray):
        pass

    class OtherArray(xp.ndarray):
        pass

    source = xp.array([1, 2, 3])
    expected_msg = "Cannot specify output type twice"

    with pytest.raises(ValueError, match=expected_msg):
        source.view(dtype=MyArray, type=OtherArray)
235312

236313

237314
@pytest.mark.parametrize(

dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,6 @@ def __array_finalize__(self, obj):
466466
self.info = getattr(obj, "info", None)
467467

468468

469-
@pytest.mark.skip("subclass array is not supported")
470469
class TestSubclassArrayView:
471470

472471
def test_view_casting(self):

0 commit comments

Comments
 (0)