Clarification on using preprocessing with tf.data with multiple backends #21029
Comments
You mean tf.data as a data generator? |
@edge7 Say
|
@edge7 you are using
Now, I was digging through the source code and found that the built-in preprocessing layers use some sort of dynamic backend-switching mechanism. That is probably the key here, but I'm not sure yet. Take a look at the following example: it works with the tensorflow backend but not with the torch or jax backend.

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # tensorflow, torch, jax

import numpy as np
import tensorflow as tf
import keras
from keras import layers

print(keras.backend.backend())

a = np.ones((4, 50, 50, 3)).astype(np.float32)
b = np.ones((4, 1)).astype(np.float32)

# Built-in augmentation layers wrapped in a Sequential model
augmentation_layers = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

# Apply the augmentation inside the tf.data pipeline
dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(
    lambda x, y: (augmentation_layers(x), y),
    num_parallel_calls=tf.data.AUTOTUNE,
)

inputs = keras.Input(shape=(50, 50, 3))
x = keras.layers.Flatten()(inputs)
output = keras.layers.Dense(1)(x)
model = keras.Model(inputs, output)
model.compile(loss='binary_crossentropy')
model.fit(dataset, epochs=1)
```

With the torch and jax backends, it gives:

```
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-1-0bc36fd74384> in <cell line: 0>()
     23 dataset = tf.data.Dataset.from_tensor_slices((a, b))
     24 dataset = dataset.batch(3, drop_remainder=True)
---> 25 dataset = dataset.map(
     26     lambda x, y: (augmentation_layers(x), y),
     27     num_parallel_calls=tf.data.AUTOTUNE

23 frames

/usr/local/lib/python3.11/dist-packages/optree/ops.py in tree_map(func, tree, is_leaf, none_is_leaf, namespace, *rests)
    764     leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    765     flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
--> 766     return treespec.unflatten(map(func, *flat_args))
    767
    768

NotImplementedError: in user code:

    File "<ipython-input-1-0bc36fd74384>", line 26, in None  *
        lambda x, y: (augmentation_layers(x), y)
    File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.11/dist-packages/optree/ops.py", line 766, in tree_map
        return treespec.unflatten(map(func, *flat_args))

    NotImplementedError: Cannot convert a symbolic tf.Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
```
|
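For comparison, here is a minimal sketch (not something proposed in this thread) that sidesteps the backend question: if the augmentation has to live in tf.data, it can be written with TensorFlow's own tf.image ops, which are ordinary graph ops and should trace inside Dataset.map whatever KERAS_BACKEND is set to. tf.image has no direct one-call equivalent of RandomRotation or RandomZoom, so this only covers a flip and a brightness shift; the function name augment is illustrative.

```python
import numpy as np
import tensorflow as tf


def augment(image, label):
    # Hypothetical TF-only augmentation: no keras.ops here, so nothing is
    # dispatched to jax.numpy or torch while tf.data traces this function.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label


a = np.ones((4, 50, 50, 3), dtype=np.float32)
b = np.ones((4, 1), dtype=np.float32)

dataset = (
    tf.data.Dataset.from_tensor_slices((a, b))
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)  # per-image augmentation
    .batch(3, drop_remainder=True)
)

for x, y in dataset.take(1):
    print(x.shape, y.shape)
```

Since nothing Keras-specific runs inside the map function, the resulting dataset should be usable with model.fit under the jax or torch backend in the same way as an un-augmented tf.data dataset.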
Hi. I see a specific test named
|
Something like this also seems to work, maybe a bit crappy?

```python
import os
```
|
@edge7 I understand it as not being officially supported, which makes using |
@edge7 I found a solution, could you please confirm it? For such cases, keras.layers.Pipeline can be used: https://keras.io/api/layers/preprocessing_layers/image_augmentation/pipeline/ Instead of

```python
augmentation_layers = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)
```

you should do

```python
augmentation_layers = keras.layers.Pipeline(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)
```
|
But this got stuck when we built a custom layer like this:

```python
import keras
from keras import layers, ops


class RandomColorJitter(keras.layers.Layer):
    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, **kwargs):
        super().__init__(**kwargs)
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def call(self, inputs, training=None):
        if training:
            # Apply random transformations only during training
            inputs = self.random_brightness(inputs)
            inputs = self.random_contrast(inputs)
            inputs = self.random_saturation(inputs)
            inputs = self.random_hue(inputs)
        return inputs

    def random_brightness(self, image):
        if self.brightness > 0:
            # Generate a random brightness factor
            factor = keras.random.uniform(
                shape=(),
                minval=1 - self.brightness,
                maxval=1 + self.brightness,
            )
            image = image * factor
            image = ops.clip(image, 0.0, 1.0)  # Clip to valid range
        return image

    def random_contrast(self, image):
        if self.contrast > 0:
            # Generate a random contrast factor
            factor = keras.random.uniform(
                shape=(),
                minval=1 - self.contrast,
                maxval=1 + self.contrast,
            )
            mean = ops.mean(image, axis=[-3, -2, -1], keepdims=True)
            image = (image - mean) * factor + mean
            image = ops.clip(image, 0.0, 1.0)  # Clip to valid range
        return image

    def random_saturation(self, image):
        if self.saturation > 0:
            # Convert RGB to grayscale
            grayscale = ops.mean(image, axis=-1, keepdims=True)
            # Generate a random saturation factor
            factor = keras.random.uniform(
                shape=(),
                minval=1 - self.saturation,
                maxval=1 + self.saturation,
            )
            image = grayscale + (image - grayscale) * factor
            image = ops.clip(image, 0.0, 1.0)  # Clip to valid range
        return image

    def random_hue(self, image):
        if self.hue > 0:
            # Convert RGB to HSV
            image = self.rgb_to_hsv(image)
            # Generate a random hue shift
            hue_shift = keras.random.uniform(
                shape=(),
                minval=-self.hue,
                maxval=self.hue,
            )
            image = ops.concatenate(
                [
                    (image[..., 0:1] + hue_shift) % 1.0,  # Hue
                    image[..., 1:2],  # Saturation
                    image[..., 2:3],  # Value
                ],
                axis=-1,
            )
            # Convert back to RGB
            image = self.hsv_to_rgb(image)
        return image

    def rgb_to_hsv(self, image):
        # Placeholder: identity, no real color-space conversion
        return image

    def hsv_to_rgb(self, image):
        # Placeholder: identity, no real color-space conversion
        return image

    def get_config(self):
        config = super().get_config()
        config.update({
            "brightness": self.brightness,
            "contrast": self.contrast,
            "saturation": self.saturation,
            "hue": self.hue,
        })
        return config
```

and added it to the pipeline:

```python
augmentation_layers = keras.layers.Pipeline(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
        RandomColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
)
```

Mapping it over the dataset as before fails with the jax backend:

```
NotImplementedError: in user code:

    File "<ipython-input-6-d080d7f65da4>", line 22, in None  *
        lambda x, y: (augmentation_layers(x), y)
    File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.11/dist-packages/optree/ops.py", line 766, in tree_map
        return treespec.unflatten(map(func, *flat_args))
    File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py", line 5732, in asarray
        return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
    File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py", line 5538, in array
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py", line 5538, in <listcomp>
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py", line 5584, in _convert_to_array_if_dtype_fails
        return np.asarray(x)

    NotImplementedError: Exception encountered when calling Pipeline.call().

    Cannot convert a symbolic tf.Tensor (EnsureShape_1:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

    Arguments received by Pipeline.call():
      • inputs=<tf.Tensor 'args_0:0' shape=(3, 50, 50, 3) dtype=float32>
      • training=True
      • mask=None
```
|
This looks like the official approach, so I'd go for this without doubt.
|
Thanks. One last question, as I was trying to use custom layers/ops
|
Say I have a preprocessing layer, e.g. image resizing or any augmentation method, built with keras.ops, and I use it with the tf.data API. With another backend, e.g. torch/jax, I can train the model with the tf.data API. Now, with jax/torch as the backend and keras.ops preprocessing inside tf.data, will it work with model.fit? Since keras.ops with a jax/torch backend calls the respective backend operations, will tf.data with those preprocessing layers work as expected?
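As the tracebacks in this thread suggest, keras.ops called inside Dataset.map dispatch to the configured backend (for jax, jax.numpy), which cannot consume the symbolic tf.Tensor that tf.data passes in. One backend-independent option for arbitrary Python/NumPy preprocessing, sketched here under that assumption, is tf.numpy_function: it calls a Python function on NumPy arrays from inside the tf.data graph, so the function body never sees a symbolic tensor. color_jitter_np and tf_color_jitter are hypothetical names. The caveats are that the wrapped function runs in the Python interpreter (holding the GIL), so it is usually slower than native TF ops, and its body is not serialized into the graph.

```python
import numpy as np
import tensorflow as tf

_rng = np.random.default_rng(0)


def color_jitter_np(image, brightness=0.2, contrast=0.2):
    # Hypothetical NumPy-only brightness/contrast jitter on a batch (N, H, W, C).
    factor = _rng.uniform(1 - brightness, 1 + brightness)
    image = np.clip(image * factor, 0.0, 1.0)
    factor = _rng.uniform(1 - contrast, 1 + contrast)
    mean = image.mean(axis=(-3, -2, -1), keepdims=True)  # per-image mean
    image = np.clip((image - mean) * factor + mean, 0.0, 1.0)
    return image.astype(np.float32)


def tf_color_jitter(image, label):
    # Runs color_jitter_np eagerly on NumPy arrays inside the tf.data graph.
    out = tf.numpy_function(color_jitter_np, [image], tf.float32)
    out.set_shape(image.shape)  # numpy_function loses the static shape
    return out, label


a = np.ones((4, 50, 50, 3), dtype=np.float32)
b = np.ones((4, 1), dtype=np.float32)

dataset = (
    tf.data.Dataset.from_tensor_slices((a, b))
    .batch(3, drop_remainder=True)
    .map(tf_color_jitter)
)
```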