Skip to content

Commit 1ea81a1

Browse files
authored
Add tf.RaggedTensor support to Embedding layer. (#21171)
Adds support for indices indices in the form of a `tf.RaggedTensor` to the `Embedding` layer by adding support to `ops.take`. The output is also ragged. Also: - adds support for negative indices in the sparse tensor use case. - adds support for ragged tensors in `TestCase.run_layer_test`.
1 parent 128e280 commit 1ea81a1

File tree

6 files changed

+169
-34
lines changed

6 files changed

+169
-34
lines changed

keras/src/backend/tensorflow/numpy.py

+30-22
Original file line numberDiff line numberDiff line change
@@ -2185,42 +2185,50 @@ def swapaxes(x, axis1, axis2):
21852185

21862186

21872187
def take(x, indices, axis=None):
2188+
x = convert_to_tensor(x)
2189+
if axis is None:
2190+
x = tf.reshape(x, (-1,))
2191+
axis = 0
2192+
2193+
def fix_negative_indices(i):
2194+
# Correct the indices using "fill" mode which is the same as in jax
2195+
return tf.where(i < 0, i + tf.cast(tf.shape(x)[axis], i.dtype), i)
2196+
21882197
if isinstance(indices, tf.SparseTensor):
21892198
if x.dtype not in (tf.float16, tf.float32, tf.float64, tf.bfloat16):
21902199
warnings.warn(
21912200
"`take` with the TensorFlow backend does not support "
21922201
f"`x.dtype={x.dtype}` when `indices` is a sparse tensor; "
21932202
"densifying `indices`."
21942203
)
2195-
return take(x, convert_to_tensor(indices, sparse=False), axis=axis)
2196-
if axis is None:
2197-
x = tf.reshape(x, (-1,))
2204+
indices = convert_to_tensor(indices, sparse=False)
21982205
elif axis != 0:
21992206
warnings.warn(
22002207
"`take` with the TensorFlow backend does not support "
22012208
f"`axis={axis}` when `indices` is a sparse tensor; "
22022209
"densifying `indices`."
22032210
)
2204-
return take(x, convert_to_tensor(indices, sparse=False), axis=axis)
2205-
output = tf.nn.safe_embedding_lookup_sparse(
2206-
embedding_weights=tf.convert_to_tensor(x),
2207-
sparse_ids=tf.sparse.expand_dims(indices, axis=-1),
2208-
default_id=0,
2209-
)
2210-
output.set_shape(indices.shape + output.shape[len(indices.shape) :])
2211-
return output
2211+
indices = convert_to_tensor(indices, sparse=False)
2212+
else:
2213+
indices = sparse.sparse_with_values(
2214+
indices, fix_negative_indices(indices.values)
2215+
)
2216+
# `expand_dims` on `indices` prevents combiner from being applied.
2217+
output = tf.nn.safe_embedding_lookup_sparse(
2218+
embedding_weights=tf.convert_to_tensor(x),
2219+
sparse_ids=tf.sparse.expand_dims(indices, axis=-1),
2220+
default_id=0,
2221+
)
2222+
output.set_shape(indices.shape + output.shape[len(indices.shape) :])
2223+
return output
2224+
elif isinstance(indices, tf.RaggedTensor):
2225+
indices = indices.with_values(fix_negative_indices(indices.values))
2226+
if axis == 0:
2227+
return tf.nn.embedding_lookup(x, indices)
2228+
else:
2229+
return tf.gather(x, indices, axis=axis)
22122230

2213-
x = convert_to_tensor(x)
2214-
indices = convert_to_tensor(indices)
2215-
if axis is None:
2216-
x = tf.reshape(x, [-1])
2217-
axis = 0
2218-
# Correct the indices using "fill" mode which is the same as in jax
2219-
indices = tf.where(
2220-
indices < 0,
2221-
indices + tf.cast(tf.shape(x)[axis], indices.dtype),
2222-
indices,
2223-
)
2231+
indices = fix_negative_indices(convert_to_tensor(indices))
22242232
return tf.gather(x, indices, axis=axis)
22252233

22262234

keras/src/layers/core/embedding.py

+7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras.src import quantizers
99
from keras.src import regularizers
1010
from keras.src.api_export import keras_export
11+
from keras.src.backend import KerasTensor
1112
from keras.src.layers.layer import Layer
1213

1314

@@ -156,6 +157,12 @@ def compute_mask(self, inputs, mask=None):
156157
def compute_output_shape(self, input_shape):
157158
return (*input_shape, self.output_dim)
158159

160+
def compute_output_spec(self, inputs):
161+
output_shape = (*inputs.shape, self.output_dim)
162+
return KerasTensor(
163+
output_shape, dtype=self.compute_dtype, ragged=inputs.ragged
164+
)
165+
159166
def enable_lora(
160167
self,
161168
rank,

keras/src/layers/core/embedding_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ def test_sparse(self):
6161
supports_masking=False,
6262
)
6363

64+
@pytest.mark.skipif(
65+
not backend.SUPPORTS_RAGGED_TENSORS,
66+
reason="Backend does not support ragged tensors.",
67+
)
68+
def test_ragged(self):
69+
self.run_layer_test(
70+
layers.Embedding,
71+
{"input_dim": 5, "output_dim": 4},
72+
input_shape=(2, 3),
73+
input_dtype="int32",
74+
input_ragged=True,
75+
expected_output_shape=(2, None, 4),
76+
expected_output_ragged=True,
77+
expected_num_trainable_weights=1,
78+
expected_num_non_trainable_weights=0,
79+
expected_num_seed_generators=0,
80+
expected_num_losses=0,
81+
supports_masking=False,
82+
# run_training_check=False,
83+
)
84+
6485
def test_correctness(self):
6586
layer = layers.Embedding(input_dim=3, output_dim=2)
6687
layer.build()

keras/src/ops/numpy.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5425,15 +5425,17 @@ def compute_output_spec(self, x, indices):
54255425
x_shape = list(x.shape)
54265426
if isinstance(indices, KerasTensor):
54275427
indices_shape = list(indices.shape)
5428+
ragged = indices.ragged
54285429
else:
54295430
indices_shape = list(getattr(np.array(indices), "shape", []))
5431+
ragged = False
54305432
if self.axis is None:
54315433
return KerasTensor(indices_shape, dtype=x.dtype)
54325434

54335435
# make sure axis is non-negative
54345436
axis = len(x_shape) + self.axis if self.axis < 0 else self.axis
54355437
output_shape = x_shape[:axis] + indices_shape + x_shape[axis + 1 :]
5436-
return KerasTensor(output_shape, dtype=x.dtype)
5438+
return KerasTensor(output_shape, dtype=x.dtype, ragged=ragged)
54375439

54385440

54395441
@keras_export(["keras.ops.take", "keras.ops.numpy.take"])

keras/src/ops/numpy_test.py

+46-2
Original file line numberDiff line numberDiff line change
@@ -3078,17 +3078,61 @@ def test_take_sparse(self, dtype, axis):
30783078
if backend.backend() == "tensorflow":
30793079
import tensorflow as tf
30803080

3081-
indices = tf.SparseTensor([[0, 0], [1, 2]], [1, 2], (2, 3))
3081+
indices = tf.SparseTensor([[0, 0], [1, 2]], [-1, 2], (2, 3))
30823082
elif backend.backend() == "jax":
30833083
import jax.experimental.sparse as jax_sparse
30843084

3085-
indices = jax_sparse.BCOO(([1, 2], [[0, 0], [1, 2]]), shape=(2, 3))
3085+
indices = jax_sparse.BCOO(([-1, 2], [[0, 0], [1, 2]]), shape=(2, 3))
30863086

30873087
self.assertAllClose(
30883088
knp.take(x, indices, axis=axis),
30893089
np.take(x, backend.convert_to_numpy(indices), axis=axis),
30903090
)
30913091

3092+
@parameterized.named_parameters(
3093+
named_product(
3094+
[
3095+
{"testcase_name": "axis_none", "axis": None},
3096+
{"testcase_name": "axis_0", "axis": 0},
3097+
{"testcase_name": "axis_1", "axis": 1},
3098+
{"testcase_name": "axis_minus1", "axis": -1},
3099+
],
3100+
dtype=[
3101+
"float16",
3102+
"float32",
3103+
"float64",
3104+
"uint8",
3105+
"int8",
3106+
"int16",
3107+
"int32",
3108+
],
3109+
)
3110+
)
3111+
@pytest.mark.skipif(
3112+
not backend.SUPPORTS_RAGGED_TENSORS,
3113+
reason="Backend does not support ragged tensors.",
3114+
)
3115+
def test_take_ragged(self, dtype, axis):
3116+
rng = np.random.default_rng(0)
3117+
x = (4 * rng.standard_normal((3, 4, 5))).astype(dtype)
3118+
3119+
if backend.backend() == "tensorflow":
3120+
import tensorflow as tf
3121+
3122+
indices = tf.ragged.constant([[2], [0, -1, 1]])
3123+
mask = backend.convert_to_numpy(tf.ones_like(indices))
3124+
3125+
if axis == 0:
3126+
mask = np.expand_dims(mask, (2, 3))
3127+
elif axis == 1:
3128+
mask = np.expand_dims(mask, (2,))
3129+
3130+
self.assertAllClose(
3131+
knp.take(x, indices, axis=axis),
3132+
np.take(x, backend.convert_to_numpy(indices), axis=axis)
3133+
* mask.astype(dtype),
3134+
)
3135+
30923136
def test_take_along_axis(self):
30933137
x = np.arange(24).reshape([1, 2, 3, 4])
30943138
indices = np.ones([1, 4, 1, 1], dtype=np.int32)

keras/src/testing/test_case.py

+62-9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from keras.src.backend.common import standardize_dtype
1717
from keras.src.backend.common.global_state import clear_session
1818
from keras.src.backend.common.keras_tensor import KerasTensor
19+
from keras.src.losses.losses import MeanSquaredError
1920
from keras.src.models import Model
2021
from keras.src.utils import traceback_utils
2122

@@ -100,6 +101,22 @@ def assertSparse(self, x, sparse=True):
100101
f"Backend {backend.backend()} does not support sparse tensors",
101102
)
102103

104+
def assertRagged(self, x, ragged=True):
105+
if isinstance(x, KerasTensor):
106+
self.assertEqual(x.ragged, ragged)
107+
elif backend.backend() == "tensorflow":
108+
import tensorflow as tf
109+
110+
if ragged:
111+
self.assertIsInstance(x, tf.RaggedTensor)
112+
else:
113+
self.assertNotIsInstance(x, tf.RaggedTensor)
114+
else:
115+
self.assertFalse(
116+
ragged,
117+
f"Backend {backend.backend()} does not support ragged tensors",
118+
)
119+
103120
def assertDType(self, x, dtype, msg=None):
104121
if hasattr(x, "dtype"):
105122
x_dtype = backend.standardize_dtype(x.dtype)
@@ -159,11 +176,13 @@ def run_layer_test(
159176
input_shape=None,
160177
input_dtype=None,
161178
input_sparse=False,
179+
input_ragged=False,
162180
input_data=None,
163181
call_kwargs=None,
164182
expected_output_shape=None,
165183
expected_output_dtype=None,
166184
expected_output_sparse=False,
185+
expected_output_ragged=False,
167186
expected_output=None,
168187
expected_num_trainable_weights=None,
169188
expected_num_non_trainable_weights=None,
@@ -188,6 +207,8 @@ def run_layer_test(
188207
input_dtype: Corresponding input dtype.
189208
input_sparse: Whether the input is a sparse tensor (this requires
190209
the backend to support sparse tensors).
210+
input_ragged: Whether the input is a ragged tensor (this requires
211+
the backend to support ragged tensors).
191212
input_data: Tensor (or list/dict of tensors)
192213
to call the layer on.
193214
call_kwargs: Dict of arguments to use when calling the
@@ -198,6 +219,8 @@ def run_layer_test(
198219
expected_output_dtype: dtype expected as output.
199220
expected_output_sparse: Whether the output is expected to be sparse
200221
(this requires the backend to support sparse tensors).
222+
expected_output_ragged: Whether the output is expected to be ragged
223+
(this requires the backend to support ragged tensors).
201224
expected_output: Expected output tensor -- only
202225
to be specified if input_data is provided.
203226
expected_num_trainable_weights: Expected number
@@ -280,7 +303,7 @@ def run_layer_test(
280303
if input_data is not None or input_shape is not None:
281304
if input_data is None:
282305
input_data = create_eager_tensors(
283-
input_shape, input_dtype, input_sparse
306+
input_shape, input_dtype, input_sparse, input_ragged
284307
)
285308
layer = layer_cls(**init_kwargs)
286309
if isinstance(input_data, dict):
@@ -357,7 +380,13 @@ def run_output_asserts(layer, output, eager=False):
357380
if expected_output_shape is not None:
358381

359382
def verify_shape(expected_shape, x):
360-
return expected_shape == x.shape
383+
shape = x.shape
384+
if len(shape) != len(expected_shape):
385+
return False
386+
for expected_dim, dim in zip(expected_shape, shape):
387+
if expected_dim is not None and expected_dim != dim:
388+
return False
389+
return True
361390

362391
shapes_match = tree.map_structure_up_to(
363392
output, verify_shape, expected_output_shape, output
@@ -383,6 +412,9 @@ def verify_dtype(expected_dtype, x):
383412
if expected_output_sparse:
384413
for x in tree.flatten(output):
385414
self.assertSparse(x)
415+
if expected_output_ragged:
416+
for x in tree.flatten(output):
417+
self.assertRagged(x)
386418
if eager:
387419
if expected_output is not None:
388420
self.assertEqual(type(expected_output), type(output))
@@ -430,7 +462,11 @@ def data_generator():
430462
jit_compile = "auto"
431463
if backend.backend() == "tensorflow" and input_sparse:
432464
jit_compile = False
433-
model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile)
465+
model.compile(
466+
optimizer="sgd",
467+
loss=MeanSquaredError(reduction="sum"),
468+
jit_compile=jit_compile,
469+
)
434470
model.fit(data_generator(), steps_per_epoch=1, verbose=0)
435471

436472
# Build test.
@@ -452,13 +488,13 @@ def data_generator():
452488
if input_shape is None:
453489
keras_tensor_inputs = tree.map_structure(
454490
lambda x: create_keras_tensors(
455-
ops.shape(x), x.dtype, input_sparse
491+
ops.shape(x), x.dtype, input_sparse, input_ragged
456492
),
457493
input_data,
458494
)
459495
else:
460496
keras_tensor_inputs = create_keras_tensors(
461-
input_shape, input_dtype, input_sparse
497+
input_shape, input_dtype, input_sparse, input_ragged
462498
)
463499
layer = layer_cls(**init_kwargs)
464500
if isinstance(keras_tensor_inputs, dict):
@@ -589,22 +625,24 @@ def uses_cpu():
589625
return False
590626

591627

592-
def create_keras_tensors(input_shape, dtype, sparse):
628+
def create_keras_tensors(input_shape, dtype, sparse, ragged):
593629
if isinstance(input_shape, dict):
594630
return {
595631
utils.removesuffix(k, "_shape"): KerasTensor(
596-
v, dtype=dtype[k], sparse=sparse
632+
v, dtype=dtype[k], sparse=sparse, ragged=ragged
597633
)
598634
for k, v in input_shape.items()
599635
}
600636
return map_shape_dtype_structure(
601-
lambda shape, dt: KerasTensor(shape, dtype=dt, sparse=sparse),
637+
lambda shape, dt: KerasTensor(
638+
shape, dtype=dt, sparse=sparse, ragged=ragged
639+
),
602640
input_shape,
603641
dtype,
604642
)
605643

606644

607-
def create_eager_tensors(input_shape, dtype, sparse):
645+
def create_eager_tensors(input_shape, dtype, sparse, ragged):
608646
from keras.src.backend import random
609647

610648
if set(tree.flatten(dtype)).difference(
@@ -651,6 +689,21 @@ def create_fn(shape, dt):
651689
f"Sparse is unsupported with backend {backend.backend()}"
652690
)
653691

692+
elif ragged:
693+
if backend.backend() == "tensorflow":
694+
import tensorflow as tf
695+
696+
def create_fn(shape, dt):
697+
rng = np.random.default_rng(0)
698+
x = (4 * rng.standard_normal(shape)).astype(dt)
699+
x = np.multiply(x, rng.random(shape) < 0.7)
700+
return tf.RaggedTensor.from_tensor(x, padding=0)
701+
702+
else:
703+
raise ValueError(
704+
f"Ragged is unsupported with backend {backend.backend()}"
705+
)
706+
654707
else:
655708

656709
def create_fn(shape, dt):

0 commit comments

Comments
 (0)