@@ -313,37 +313,45 @@ def test_repeat(x, kw, data):
313313
314314 assume (n_repititions <= hh .SQRT_MAX_ARRAY_SIZE )
315315
316- out = xp .repeat (x , repeats , ** kw )
317- ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
318- if axis is None :
319- expected_shape = (n_repititions ,)
320- else :
321- expected_shape = list (shape )
322- expected_shape [axis ] = n_repititions
323- expected_shape = tuple (expected_shape )
324- ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
316+ repro_snippet = ph .format_snippet (f"xp.repeat({ x !r} ,{ repeats !r} , **kw) with { kw = } " )
317+ try :
318+ out = xp .repeat (x , repeats , ** kw )
325319
326- # Test values
320+ ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
321+ if axis is None :
322+ expected_shape = (n_repititions ,)
323+ else :
324+ expected_shape = list (shape )
325+ expected_shape [axis ] = n_repititions
326+ expected_shape = tuple (expected_shape )
327+ ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
328+
329+ # Test values
330+
331+ if isinstance (repeats , int ):
332+ repeats_array = xp .full (size , repeats , dtype = xp .int32 )
333+ else :
334+ repeats_array = repeats
335+
336+ if kw .get ("axis" ) is None :
337+ x = xp .reshape (x , (- 1 ,))
338+ axis = 0
339+
340+ for idx , in sh .iter_indices (x .shape , skip_axes = axis ):
341+ x_slice = x [idx ]
342+ out_slice = out [idx ]
343+ start = 0
344+ for i , count in enumerate (repeats_array ):
345+ end = start + count
346+ ph .assert_array_elements ("repeat" , out = out_slice [start :end ],
347+ expected = xp .full ((count ,), x_slice [i ], dtype = x .dtype ),
348+ kw = kw )
349+ start = end
350+
351+ except Exception as exc :
352+ exc .add_note (repro_snippet )
353+ raise
327354
328- if isinstance (repeats , int ):
329- repeats_array = xp .full (size , repeats , dtype = xp .int32 )
330- else :
331- repeats_array = repeats
332-
333- if kw .get ("axis" ) is None :
334- x = xp .reshape (x , (- 1 ,))
335- axis = 0
336-
337- for idx , in sh .iter_indices (x .shape , skip_axes = axis ):
338- x_slice = x [idx ]
339- out_slice = out [idx ]
340- start = 0
341- for i , count in enumerate (repeats_array ):
342- end = start + count
343- ph .assert_array_elements ("repeat" , out = out_slice [start :end ],
344- expected = xp .full ((count ,), x_slice [i ], dtype = x .dtype ),
345- kw = kw )
346- start = end
347355
348356reshape_shape = st .shared (hh .shapes (), key = "reshape_shape" )
349357
0 commit comments