
Fix Embedding test with ragged tensors on GPU. #21177


Merged: 1 commit, merged on Apr 17, 2025
11 changes: 7 additions & 4 deletions keras/src/testing/test_case.py
@@ -16,7 +16,7 @@
 from keras.src.backend.common import standardize_dtype
 from keras.src.backend.common.global_state import clear_session
 from keras.src.backend.common.keras_tensor import KerasTensor
-from keras.src.losses.losses import MeanSquaredError
+from keras.src.losses.loss import Loss
 from keras.src.models import Model
 from keras.src.utils import traceback_utils

@@ -446,6 +446,11 @@ def data_generator():
                 while True:
                     yield data
 
+            # Single op loss to avoid compilation issues with ragged / sparse.
+            class TestLoss(Loss):
+                def __call__(self, y_true, y_pred, sample_weight=None):
+                    return ops.sum(y_pred)
+
             # test the "default" path for each backend by setting
             # jit_compile="auto".
             # for tensorflow and jax backends auto is jitted
@@ -463,9 +468,7 @@ def data_generator():
             if backend.backend() == "tensorflow" and input_sparse:
                 jit_compile = False
             model.compile(
-                optimizer="sgd",
-                loss=MeanSquaredError(reduction="sum"),
-                jit_compile=jit_compile,
+                optimizer="sgd", loss=TestLoss(), jit_compile=jit_compile
             )
             model.fit(data_generator(), steps_per_epoch=1, verbose=0)
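
The pattern in this change can be illustrated outside the test harness. Below is a minimal standalone sketch, not part of the PR, assuming the public Keras 3 API (keras.losses.Loss, keras.ops.sum, model.compile with jit_compile): a loss that overrides __call__ with a single reduction op, so none of the base class's reduction and sample-weight handling gets traced during compilation. The SumLoss name, model, and data here are illustrative only.

import numpy as np
import keras
from keras import ops


class SumLoss(keras.losses.Loss):  # hypothetical name for illustration
    def __call__(self, y_true, y_pred, sample_weight=None):
        # Single reduction op on the predictions; y_true and
        # sample_weight are intentionally ignored, mirroring the
        # TestLoss added in the diff above.
        return ops.sum(y_pred)


# Illustrative usage on plain dense data (any backend).
model = keras.Sequential([keras.layers.Dense(4), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss=SumLoss(), jit_compile="auto")

x = np.random.random((8, 16)).astype("float32")
y = np.zeros((8, 1), dtype="float32")
model.fit(x, y, epochs=1, verbose=0)

Using a trivial loss like this keeps the layer under test as the only nontrivial computation in the training step, which is the point of the swap away from MeanSquaredError(reduction="sum") in the test.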