File tree 1 file changed +7
-4
lines changed
1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change 16
16
from keras .src .backend .common import standardize_dtype
17
17
from keras .src .backend .common .global_state import clear_session
18
18
from keras .src .backend .common .keras_tensor import KerasTensor
19
- from keras .src .losses .losses import MeanSquaredError
19
+ from keras .src .losses .loss import Loss
20
20
from keras .src .models import Model
21
21
from keras .src .utils import traceback_utils
22
22
@@ -446,6 +446,11 @@ def data_generator():
446
446
while True :
447
447
yield data
448
448
449
+ # Single op loss to avoid compilation issues with ragged / sparse.
450
+ class TestLoss (Loss ):
451
+ def __call__ (self , y_true , y_pred , sample_weight = None ):
452
+ return ops .sum (y_pred )
453
+
449
454
# test the "default" path for each backend by setting
450
455
# jit_compile="auto".
451
456
# for tensorflow and jax backends auto is jitted
@@ -463,9 +468,7 @@ def data_generator():
463
468
if backend .backend () == "tensorflow" and input_sparse :
464
469
jit_compile = False
465
470
model .compile (
466
- optimizer = "sgd" ,
467
- loss = MeanSquaredError (reduction = "sum" ),
468
- jit_compile = jit_compile ,
471
+ optimizer = "sgd" , loss = TestLoss (), jit_compile = jit_compile
469
472
)
470
473
model .fit (data_generator (), steps_per_epoch = 1 , verbose = 0 )
471
474
You can’t perform that action at this time.
0 commit comments