16
16
from keras .src .backend .common import standardize_dtype
17
17
from keras .src .backend .common .global_state import clear_session
18
18
from keras .src .backend .common .keras_tensor import KerasTensor
19
+ from keras .src .losses .losses import MeanSquaredError
19
20
from keras .src .models import Model
20
21
from keras .src .utils import traceback_utils
21
22
@@ -100,6 +101,22 @@ def assertSparse(self, x, sparse=True):
100
101
f"Backend { backend .backend ()} does not support sparse tensors" ,
101
102
)
102
103
104
+ def assertRagged (self , x , ragged = True ):
105
+ if isinstance (x , KerasTensor ):
106
+ self .assertEqual (x .ragged , ragged )
107
+ elif backend .backend () == "tensorflow" :
108
+ import tensorflow as tf
109
+
110
+ if ragged :
111
+ self .assertIsInstance (x , tf .RaggedTensor )
112
+ else :
113
+ self .assertNotIsInstance (x , tf .RaggedTensor )
114
+ else :
115
+ self .assertFalse (
116
+ ragged ,
117
+ f"Backend { backend .backend ()} does not support ragged tensors" ,
118
+ )
119
+
103
120
def assertDType (self , x , dtype , msg = None ):
104
121
if hasattr (x , "dtype" ):
105
122
x_dtype = backend .standardize_dtype (x .dtype )
@@ -159,11 +176,13 @@ def run_layer_test(
159
176
input_shape = None ,
160
177
input_dtype = None ,
161
178
input_sparse = False ,
179
+ input_ragged = False ,
162
180
input_data = None ,
163
181
call_kwargs = None ,
164
182
expected_output_shape = None ,
165
183
expected_output_dtype = None ,
166
184
expected_output_sparse = False ,
185
+ expected_output_ragged = False ,
167
186
expected_output = None ,
168
187
expected_num_trainable_weights = None ,
169
188
expected_num_non_trainable_weights = None ,
@@ -188,6 +207,8 @@ def run_layer_test(
188
207
input_dtype: Corresponding input dtype.
189
208
input_sparse: Whether the input is a sparse tensor (this requires
190
209
the backend to support sparse tensors).
210
+ input_ragged: Whether the input is a ragged tensor (this requires
211
+ the backend to support ragged tensors).
191
212
input_data: Tensor (or list/dict of tensors)
192
213
to call the layer on.
193
214
call_kwargs: Dict of arguments to use when calling the
@@ -198,6 +219,8 @@ def run_layer_test(
198
219
expected_output_dtype: dtype expected as output.
199
220
expected_output_sparse: Whether the output is expected to be sparse
200
221
(this requires the backend to support sparse tensors).
222
+ expected_output_ragged: Whether the output is expected to be ragged
223
+ (this requires the backend to support ragged tensors).
201
224
expected_output: Expected output tensor -- only
202
225
to be specified if input_data is provided.
203
226
expected_num_trainable_weights: Expected number
@@ -280,7 +303,7 @@ def run_layer_test(
280
303
if input_data is not None or input_shape is not None :
281
304
if input_data is None :
282
305
input_data = create_eager_tensors (
283
- input_shape , input_dtype , input_sparse
306
+ input_shape , input_dtype , input_sparse , input_ragged
284
307
)
285
308
layer = layer_cls (** init_kwargs )
286
309
if isinstance (input_data , dict ):
@@ -357,7 +380,13 @@ def run_output_asserts(layer, output, eager=False):
357
380
if expected_output_shape is not None :
358
381
359
382
def verify_shape (expected_shape , x ):
360
- return expected_shape == x .shape
383
+ shape = x .shape
384
+ if len (shape ) != len (expected_shape ):
385
+ return False
386
+ for expected_dim , dim in zip (expected_shape , shape ):
387
+ if expected_dim is not None and expected_dim != dim :
388
+ return False
389
+ return True
361
390
362
391
shapes_match = tree .map_structure_up_to (
363
392
output , verify_shape , expected_output_shape , output
@@ -383,6 +412,9 @@ def verify_dtype(expected_dtype, x):
383
412
if expected_output_sparse :
384
413
for x in tree .flatten (output ):
385
414
self .assertSparse (x )
415
+ if expected_output_ragged :
416
+ for x in tree .flatten (output ):
417
+ self .assertRagged (x )
386
418
if eager :
387
419
if expected_output is not None :
388
420
self .assertEqual (type (expected_output ), type (output ))
@@ -430,7 +462,11 @@ def data_generator():
430
462
jit_compile = "auto"
431
463
if backend .backend () == "tensorflow" and input_sparse :
432
464
jit_compile = False
433
- model .compile (optimizer = "sgd" , loss = "mse" , jit_compile = jit_compile )
465
+ model .compile (
466
+ optimizer = "sgd" ,
467
+ loss = MeanSquaredError (reduction = "sum" ),
468
+ jit_compile = jit_compile ,
469
+ )
434
470
model .fit (data_generator (), steps_per_epoch = 1 , verbose = 0 )
435
471
436
472
# Build test.
@@ -452,13 +488,13 @@ def data_generator():
452
488
if input_shape is None :
453
489
keras_tensor_inputs = tree .map_structure (
454
490
lambda x : create_keras_tensors (
455
- ops .shape (x ), x .dtype , input_sparse
491
+ ops .shape (x ), x .dtype , input_sparse , input_ragged
456
492
),
457
493
input_data ,
458
494
)
459
495
else :
460
496
keras_tensor_inputs = create_keras_tensors (
461
- input_shape , input_dtype , input_sparse
497
+ input_shape , input_dtype , input_sparse , input_ragged
462
498
)
463
499
layer = layer_cls (** init_kwargs )
464
500
if isinstance (keras_tensor_inputs , dict ):
@@ -589,22 +625,24 @@ def uses_cpu():
589
625
return False
590
626
591
627
592
- def create_keras_tensors (input_shape , dtype , sparse ):
628
+ def create_keras_tensors (input_shape , dtype , sparse , ragged ):
593
629
if isinstance (input_shape , dict ):
594
630
return {
595
631
utils .removesuffix (k , "_shape" ): KerasTensor (
596
- v , dtype = dtype [k ], sparse = sparse
632
+ v , dtype = dtype [k ], sparse = sparse , ragged = ragged
597
633
)
598
634
for k , v in input_shape .items ()
599
635
}
600
636
return map_shape_dtype_structure (
601
- lambda shape , dt : KerasTensor (shape , dtype = dt , sparse = sparse ),
637
+ lambda shape , dt : KerasTensor (
638
+ shape , dtype = dt , sparse = sparse , ragged = ragged
639
+ ),
602
640
input_shape ,
603
641
dtype ,
604
642
)
605
643
606
644
607
- def create_eager_tensors (input_shape , dtype , sparse ):
645
+ def create_eager_tensors (input_shape , dtype , sparse , ragged ):
608
646
from keras .src .backend import random
609
647
610
648
if set (tree .flatten (dtype )).difference (
@@ -651,6 +689,21 @@ def create_fn(shape, dt):
651
689
f"Sparse is unsupported with backend { backend .backend ()} "
652
690
)
653
691
692
+ elif ragged :
693
+ if backend .backend () == "tensorflow" :
694
+ import tensorflow as tf
695
+
696
+ def create_fn (shape , dt ):
697
+ rng = np .random .default_rng (0 )
698
+ x = (4 * rng .standard_normal (shape )).astype (dt )
699
+ x = np .multiply (x , rng .random (shape ) < 0.7 )
700
+ return tf .RaggedTensor .from_tensor (x , padding = 0 )
701
+
702
+ else :
703
+ raise ValueError (
704
+ f"Ragged is unsupported with backend { backend .backend ()} "
705
+ )
706
+
654
707
else :
655
708
656
709
def create_fn (shape , dt ):
0 commit comments