@@ -644,6 +644,136 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray):
644644 res ._array_obj ._set_namespace (dpnp )
645645 return res
646646
647+ def _create_view (self , array_class , shape , dtype , strides ):
648+ """
649+ Create a view of an array with the specified class.
650+
651+ The method handles subclass instantiation by creating a usm_ndarray
652+ view and then wrapping it in the appropriate class.
653+
654+ Parameters
655+ ----------
656+ array_class : type
657+ The class to instantiate (dpnp_array or a subclass).
658+ shape : tuple
659+ Shape of the view.
660+ dtype : dtype
661+ Data type of the view (can be None to keep source's dtype).
662+ strides : tuple
663+ Strides of the view.
664+
665+ Returns
666+ -------
667+ view : array_class instance
668+ A view of the array as the specified class.
669+
670+ """
671+
672+ if dtype is None :
673+ dtype = self .dtype
674+
675+ # create the underlying usm_ndarray view
676+ usm_view = dpt .usm_ndarray (
677+ shape ,
678+ dtype = dtype ,
679+ buffer = self ._array_obj ,
680+ strides = tuple (s // dpnp .dtype (dtype ).itemsize for s in strides ),
681+ )
682+
683+ # wrap the view into the appropriate class
684+ if array_class is dpnp_array :
685+ res = dpnp_array ._create_from_usm_ndarray (usm_view )
686+ else :
687+ # for subclasses, create using __new__ and set up manually
688+ res = array_class .__new__ (array_class )
689+ res ._array_obj = usm_view
690+ res ._array_obj ._set_namespace (dpnp )
691+
692+ if hasattr (res , "__array_finalize__" ):
693+ res .__array_finalize__ (self )
694+
695+ return res
696+
697+ def _view_impl (self , dtype = None , array_class = None ):
698+ """
699+ Internal implementation of view method to avoid an issue where
700+ `type` parameter in ndarray.view method shadowing builtin type.
701+
702+ """
703+
704+ # check if dtype is actually a type
705+ if dtype is not None :
706+ if isinstance (dtype , type ) and issubclass (dtype , dpnp_array ):
707+ if array_class is not None :
708+ raise ValueError ("Cannot specify output type twice" )
709+ array_class = dtype
710+ dtype = None
711+
712+ # validate array_class parameter
713+ if not (
714+ array_class is None
715+ or isinstance (array_class , type )
716+ and issubclass (array_class , dpnp_array )
717+ ):
718+ raise ValueError ("Type must be a sub-type of ndarray type" )
719+
720+ if array_class is None :
721+ # it's a view on dpnp.ndarray
722+ array_class = self .__class__
723+
724+ old_sh = self .shape
725+ old_strides = self .strides
726+
727+ if dtype is None :
728+ return self ._create_view (array_class , old_sh , None , old_strides )
729+
730+ new_dt = dpnp .dtype (dtype )
731+ new_dt = dtu ._to_device_supported_dtype (new_dt , self .sycl_device )
732+
733+ new_itemsz = new_dt .itemsize
734+ old_itemsz = self .dtype .itemsize
735+ if new_itemsz == old_itemsz :
736+ return self ._create_view (array_class , old_sh , new_dt , old_strides )
737+
738+ ndim = self .ndim
739+ if ndim == 0 :
740+ raise ValueError (
741+ "Changing the dtype of a 0d array is only supported "
742+ "if the itemsize is unchanged"
743+ )
744+
745+ # resize on last axis only
746+ axis = ndim - 1
747+ if (
748+ old_sh [axis ] != 1
749+ and self .size != 0
750+ and old_strides [axis ] != old_itemsz
751+ ):
752+ raise ValueError (
753+ "To change to a dtype of a different size, "
754+ "the last axis must be contiguous"
755+ )
756+
757+ # normalize strides whenever itemsize changes
758+ new_strides = tuple (
759+ old_strides [i ] if i != axis else new_itemsz for i in range (ndim )
760+ )
761+
762+ new_dim = old_sh [axis ] * old_itemsz
763+ if new_dim % new_itemsz != 0 :
764+ raise ValueError (
765+ "When changing to a larger dtype, its size must be a divisor "
766+ "of the total size in bytes of the last axis of the array"
767+ )
768+
769+ # normalize shape whenever itemsize changes
770+ new_sh = tuple (
771+ old_sh [i ] if i != axis else new_dim // new_itemsz
772+ for i in range (ndim )
773+ )
774+
775+ return self ._create_view (array_class , new_sh , new_dt , new_strides )
776+
647777 def all (self , axis = None , * , out = None , keepdims = False , where = True ):
648778 """
649779 Return ``True`` if all elements evaluate to ``True``.
@@ -2322,10 +2452,18 @@ def view(self, /, dtype=None, *, type=None):
23222452
23232453 Parameters
23242454 ----------
2325- dtype : {None, str, dtype object}, optional
2455+ dtype : {None, str, dtype object, type }, optional
23262456 The desired data type of the returned view, e.g. :obj:`dpnp.float32`
2327- or :obj:`dpnp.int16`. By default, it results in the view having the
2328- same data type.
2457+ or :obj:`dpnp.int16`. Omitting it results in the view having the
2458+ same data type. Can also be a subclass of :class:`dpnp.ndarray` to
2459+ create a view of that type (this is equivalent to setting the `type`
2460+ parameter).
2461+
2462+ Default: ``None``.
2463+ type : {None, type}, optional
2464+ Type of the returned view, e.g. a subclass of :class:`dpnp.ndarray`.
2465+ If specified, the returned array will be an instance of `type`.
2466+ Omitting it results in type preservation.
23292467
23302468 Default: ``None``.
23312469
@@ -2340,11 +2478,6 @@ def view(self, /, dtype=None, *, type=None):
23402478
23412479 Only the last axis has to be contiguous.
23422480
2343- Limitations
2344- -----------
2345- Parameter `type` is supported only with default value ``None``.
2346- Otherwise, the function raises ``NotImplementedError`` exception.
2347-
23482481 Examples
23492482 --------
23502483 >>> import dpnp as np
@@ -2368,73 +2501,17 @@ def view(self, /, dtype=None, *, type=None):
23682501 [[2312, 2826],
23692502 [5396, 5910]]], dtype=int16)
23702503
2371- """
2372-
2373- if type is not None :
2374- raise NotImplementedError (
2375- "Keyword argument `type` is supported only with "
2376- f"default value ``None``, but got { type } ."
2377- )
2378-
2379- old_sh = self .shape
2380- old_strides = self .strides
2381-
2382- if dtype is None :
2383- return dpnp_array (old_sh , buffer = self , strides = old_strides )
2384-
2385- new_dt = dpnp .dtype (dtype )
2386- new_dt = dtu ._to_device_supported_dtype (new_dt , self .sycl_device )
2387-
2388- new_itemsz = new_dt .itemsize
2389- old_itemsz = self .dtype .itemsize
2390- if new_itemsz == old_itemsz :
2391- return dpnp_array (
2392- old_sh , dtype = new_dt , buffer = self , strides = old_strides
2393- )
2394-
2395- ndim = self .ndim
2396- if ndim == 0 :
2397- raise ValueError (
2398- "Changing the dtype of a 0d array is only supported "
2399- "if the itemsize is unchanged"
2400- )
2401-
2402- # resize on last axis only
2403- axis = ndim - 1
2404- if (
2405- old_sh [axis ] != 1
2406- and self .size != 0
2407- and old_strides [axis ] != old_itemsz
2408- ):
2409- raise ValueError (
2410- "To change to a dtype of a different size, "
2411- "the last axis must be contiguous"
2412- )
2504+ Creating a view with a custom ndarray subclass:
24132505
2414- # normalize strides whenever itemsize changes
2415- new_strides = tuple (
2416- old_strides [i ] if i != axis else new_itemsz for i in range (ndim )
2417- )
2418-
2419- new_dim = old_sh [axis ] * old_itemsz
2420- if new_dim % new_itemsz != 0 :
2421- raise ValueError (
2422- "When changing to a larger dtype, its size must be a divisor "
2423- "of the total size in bytes of the last axis of the array"
2424- )
2425-
2426- # normalize shape whenever itemsize changes
2427- new_sh = tuple (
2428- old_sh [i ] if i != axis else new_dim // new_itemsz
2429- for i in range (ndim )
2430- )
2506+ >>> class MyArray(np.ndarray):
2507+ ... pass
2508+ >>> x = np.array([1, 2, 3])
2509+ >>> y = x.view(MyArray)
2510+ >>> type(y)
2511+ <class 'MyArray'>
24312512
2432- return dpnp_array (
2433- new_sh ,
2434- dtype = new_dt ,
2435- buffer = self ,
2436- strides = new_strides ,
2437- )
2513+ """
2514+ return self ._view_impl (dtype = dtype , array_class = type )
24382515
24392516 @property
24402517 def usm_type (self ):
0 commit comments