Skip to content

Commit b840164

Browse files
authored
Fix Embedding test with ragged tensors on GPU. (#21177)
The loss needs to not have any non-compilable op.
1 parent 8e4b4ab commit b840164

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

keras/src/testing/test_case.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from keras.src.backend.common import standardize_dtype
1717
from keras.src.backend.common.global_state import clear_session
1818
from keras.src.backend.common.keras_tensor import KerasTensor
19-
from keras.src.losses.losses import MeanSquaredError
19+
from keras.src.losses.loss import Loss
2020
from keras.src.models import Model
2121
from keras.src.utils import traceback_utils
2222

@@ -446,6 +446,11 @@ def data_generator():
446446
while True:
447447
yield data
448448

449+
# Single op loss to avoid compilation issues with ragged / sparse.
450+
class TestLoss(Loss):
451+
def __call__(self, y_true, y_pred, sample_weight=None):
452+
return ops.sum(y_pred)
453+
449454
# test the "default" path for each backend by setting
450455
# jit_compile="auto".
451456
# for tensorflow and jax backends auto is jitted
@@ -463,9 +468,7 @@ def data_generator():
463468
if backend.backend() == "tensorflow" and input_sparse:
464469
jit_compile = False
465470
model.compile(
466-
optimizer="sgd",
467-
loss=MeanSquaredError(reduction="sum"),
468-
jit_compile=jit_compile,
471+
optimizer="sgd", loss=TestLoss(), jit_compile=jit_compile
469472
)
470473
model.fit(data_generator(), steps_per_epoch=1, verbose=0)
471474

0 commit comments

Comments
 (0)