clarification of using preprocessing with tf.data with multiple backend #21029

Open
pure-rgb opened this issue Mar 13, 2025 · 11 comments
Labels
type:support User is asking for help / asking an implementation question. Stackoverflow would be better suited.


@pure-rgb

Say I have a preprocessing layer, e.g. image resizing or some augmentation method, built with keras.ops, and I use it with the tf.data API. With another backend, e.g. torch or JAX, I can still train the model with the tf.data API. Now, with JAX or torch as the backend and keras.ops inside a tf.data pipeline, will it work with model.fit? Since keras.ops dispatches to the respective backend's operations under JAX or torch, will tf.data with those preprocessing layers work as expected?

dhantule added the type:support label on Mar 17, 2025
@edge7
Contributor

edge7 commented Mar 17, 2025

Do you mean tf.data as the data generator?
If so, yes. I have used tf.data to feed the model while using JAX as the backend. But the augmentation part was pure NumPy with the albumentations library. I think testing your use case should not take long.

@pure-rgb
Author

@edge7
You didn't understand my question or concern.

Say

[backend: tf] - [dataloader: tf.data + keras.ops] - model.fit/evaluate/predict # OK
[backend: torch] - [dataloader: tf.data + keras.ops] - model.fit/evaluate/predict # is it ok?
[backend: jax] - [dataloader: tf.data + keras.ops] - model.fit/evaluate/predict # is it ok?

@edge7
Contributor

edge7 commented Mar 18, 2025

Hi @pure-rgb

That was my understanding:

See this Colab. Is that not what you meant?

@pure-rgb
Author

@edge7 you are using tf.py_function to wrap code that I need to work without any wrapper. If we have to wrap keras.ops in tf.py_function inside tf.data, then there is no point in using tf.data as the data loader, since that sacrifices performance. We could simply use keras.utils.PyDataset instead.

Now, digging through the source code, I found that the built-in preprocessing layers use some sort of dynamic backend-switching mechanism. That is probably the key here, but I'm not sure yet. Take a look at the following example: it works with the TensorFlow backend but not with the torch or JAX backend.

import os
os.environ["KERAS_BACKEND"] = "jax" # tensorflow, torch, jax

import numpy as np

import tensorflow as tf
import keras
from keras import layers
print(keras.backend.backend())

a = np.ones((4, 50, 50, 3)).astype(np.float32)
b = np.ones((4, 1)).astype(np.float32)

augmentation_layers = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)
dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(
    lambda x, y: (augmentation_layers(x), y), 
    num_parallel_calls=tf.data.AUTOTUNE
)

inputs = keras.Input(shape=(50, 50, 3))
x = keras.layers.Flatten()(inputs)
output = keras.layers.Dense(1)(x)
model = keras.Model(inputs, output)

model.compile(loss='binary_crossentropy')
model.fit(dataset, epochs=1)

With the torch and JAX backends, it raises:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-1-0bc36fd74384> in <cell line: 0>()
     23 dataset = tf.data.Dataset.from_tensor_slices((a, b))
     24 dataset = dataset.batch(3, drop_remainder=True)
---> 25 dataset = dataset.map(
     26     lambda x, y: (augmentation_layers(x), y),
     27     num_parallel_calls=tf.data.AUTOTUNE

23 frames
/usr/local/lib/python3.11/dist-packages/optree/ops.py in tree_map(func, tree, is_leaf, none_is_leaf, namespace, *rests)
    764     leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    765     flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
--> 766     return treespec.unflatten(map(func, *flat_args))
    767 
    768 

NotImplementedError: in user code:

    File "<ipython-input-1-0bc36fd74384>", line 26, in None  *
        lambda x, y: (augmentation_layers(x), y)
    File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.11/dist-packages/optree/ops.py", line 766, in tree_map
        return treespec.unflatten(map(func, *flat_args))

    NotImplementedError: Cannot convert a symbolic tf.Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
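The message points at the cause. Inside dataset.map(), the mapped function is traced into a TensorFlow graph, so it receives symbolic tf.Tensor objects with no concrete values. Under JAX or torch, the Sequential model tries to convert its inputs to backend-native arrays, which needs concrete NumPy values and fails. The check below is a minimal TensorFlow-only sketch (the helper name `traced_fn` is hypothetical). It shows that the mapped function runs in graph mode and that plain TF ops are fine there.

```python
import numpy as np
import tensorflow as tf

# Filled while map() traces the function: (is a tf.Tensor, executing eagerly)
observations = []


def traced_fn(x, y):
    # Python side effects like this append run only at trace time, not per batch
    observations.append((isinstance(x, tf.Tensor), tf.executing_eagerly()))
    # TF ops accept symbolic tensors, so this works inside the traced graph
    return tf.image.flip_left_right(x), y


a = np.ones((4, 50, 50, 3), dtype=np.float32)
b = np.ones((4, 1), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((a, b)).batch(3, drop_remainder=True)
dataset = dataset.map(traced_fn)

batch_x, batch_y = next(iter(dataset))
```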

@edge7
Contributor

edge7 commented Mar 18, 2025

Hi.
OK, so I'm looking at some tests like this one.

I see a specific test named test_tf_data_compatibility, which is promising.
For instance, with just one layer it works fine:

single_layer = layers.RandomFlip("horizontal")
dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(
    lambda images, labels: (single_layer(images, training=True), labels),
    num_parallel_calls=tf.data.AUTOTUNE,
)


inputs = keras.Input(shape=(50, 50, 3))
x = keras.layers.Flatten()(inputs)
output = keras.layers.Dense(1)(x)
model = keras.Model(inputs, output)

model.compile(loss='binary_crossentropy')
model.fit(dataset, epochs=1)

@edge7
Contributor

edge7 commented Mar 18, 2025

This also seems to work, although it is maybe a bit hacky:

import os
os.environ["KERAS_BACKEND"] = "jax" # tensorflow, torch, jax

import numpy as np

import tensorflow as tf
import keras
from keras import layers
# TFDataLayer lives under keras.src, i.e. it is private API and may change between releases
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer

print(keras.backend.backend())

a = np.ones((4, 50, 50, 3)).astype(np.float32)
b = np.ones((4, 1)).astype(np.float32)

class ImageAugmentation(TFDataLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize augmentation layers
        self.random_flip = layers.RandomFlip("horizontal")
        self.random_rotation = layers.RandomRotation(0.1)
        self.random_zoom = layers.RandomZoom(0.2)

    def call(self, inputs, training=True):
        # Apply augmentations only when training is True (the default here)
        x = inputs
        if training:
            x = self.random_flip(x)
            x = self.random_rotation(x)
            x = self.random_zoom(x)
        return x

dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
augmentation = ImageAugmentation()

dataset = dataset.map(
    lambda images, labels: (augmentation(images, training=True), labels),
    num_parallel_calls=tf.data.AUTOTUNE,
)

print(keras.backend.backend())

inputs = keras.Input(shape=(50, 50, 3))
x = keras.layers.Flatten()(inputs)
output = keras.layers.Dense(1)(x)
model = keras.Model(inputs, output)

model.compile(loss='binary_crossentropy')
model.fit(dataset, epochs=1)
print(keras.backend.backend())

@pure-rgb
Author

@edge7
Thanks for your detailed investigation. I really appreciate it.

My understanding is that this is not officially supported, which makes using keras.layers with the tf.data API somewhat limited. That may change in the future, but for now it restricts how flexibly the tf.data API can be used.

@pure-rgb
Author

@edge7 I found a solution; could you please confirm it? For such cases, keras.layers.Pipeline can be used: https://keras.io/api/layers/preprocessing_layers/image_augmentation/pipeline/

Instead of

augmentation_layers = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

we should use

augmentation_layers = keras.layers.Pipeline(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

@pure-rgb
Author

But this breaks if we build a custom layer like:

import keras
from keras import layers, ops


class RandomColorJitter(keras.layers.Layer):
    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, **kwargs):
        super().__init__(**kwargs)
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def call(self, inputs, training=None):
        if training:
            # Apply random transformations only during training
            inputs = self.random_brightness(inputs)
            inputs = self.random_contrast(inputs)
            inputs = self.random_saturation(inputs)
            inputs = self.random_hue(inputs)
        return inputs

    def random_brightness(self, image):
        if self.brightness > 0:
            # Generate a random brightness factor (random ops live in keras.random, not keras.ops)
            factor = keras.random.uniform(
                shape=(),
                minval=1 - self.brightness,
                maxval=1 + self.brightness,
            )
            image = image * factor
            image = ops.clip(image, 0.0, 1.0)  # Clip to valid range
        return image

    def random_contrast(self, image):
        if self.contrast > 0:
            # Generate a random contrast factor
            factor = keras.random.uniform(
                shape=(),
                minval=1 - self.contrast,
                maxval=1 + self.contrast,
            )
            # Per-image mean over height, width and channels
            mean = ops.mean(image, axis=(-3, -2, -1), keepdims=True)
            image = (image - mean) * factor + mean
            image = ops.clip(image, 0.0, 1.0)  # Clip to valid range
        return image

    def random_saturation(self, image):
        if self.saturation > 0:
            # Approximate grayscale as the channel mean
            grayscale = ops.mean(image, axis=-1, keepdims=True)
            # Generate a random saturation factor
            factor = keras.random.uniform(
                shape=(),
                minval=1 - self.saturation,
                maxval=1 + self.saturation,
            )
            image = grayscale + (image - grayscale) * factor
            image = ops.clip(image, 0.0, 1.0)  # Clip to valid range
        return image

    def random_hue(self, image):
        if self.hue > 0:
            # Convert RGB to HSV (placeholder, see rgb_to_hsv below)
            image = self.rgb_to_hsv(image)
            # Generate a random hue shift
            hue_shift = keras.random.uniform(
                shape=(),
                minval=-self.hue,
                maxval=self.hue,
            )
            image = ops.concatenate(
                [
                    (image[..., 0:1] + hue_shift) % 1.0,  # Hue
                    image[..., 1:2],  # Saturation
                    image[..., 2:3],  # Value
                ],
                axis=-1,
            )
            # Convert back to RGB (placeholder, see hsv_to_rgb below)
            image = self.hsv_to_rgb(image)
        return image

    def rgb_to_hsv(self, image):
        # Placeholder: returns the image unchanged (no real color-space conversion)
        return image

    def hsv_to_rgb(self, image):
        # Placeholder: returns the image unchanged (no real color-space conversion)
        return image

    def get_config(self):
        config = super().get_config()
        config.update({
            "brightness": self.brightness,
            "contrast": self.contrast,
            "saturation": self.saturation,
            "hue": self.hue,
        })
        return config


augmentation_layers = keras.layers.Pipeline(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
        RandomColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
)

With the JAX backend, this fails with:

NotImplementedError: in user code:

    File "<ipython-input-6-d080d7f65da4>", line 22, in None  *
        lambda x, y: (augmentation_layers(x), y)
    File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.11/dist-packages/optree/ops.py", line 766, in tree_map
        return treespec.unflatten(map(func, *flat_args))
    File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py", line 5732, in asarray
        return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
    File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py", line 5538, in array
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py", line 5538, in <listcomp>
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py", line 5584, in _convert_to_array_if_dtype_fails
        return np.asarray(x)

    NotImplementedError: Exception encountered when calling Pipeline.call().
    
    Cannot convert a symbolic tf.Tensor (EnsureShape_1:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
    
    Arguments received by Pipeline.call():
      • inputs=<tf.Tensor 'args_0:0' shape=(3, 50, 50, 3) dtype=float32>
      • training=True
      • mask=None
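One possible workaround, offered as a sketch rather than an official recipe, has two parts. Keep the built-in layers in the Pipeline. Write the custom color jitter as a plain function built from tf.image ops instead of a Keras layer built from keras.ops. Everything inside dataset.map() is then TensorFlow, which is what tf.data needs regardless of KERAS_BACKEND. This swaps keras.ops for tf.image, so the jitter is no longer backend-agnostic code you could reuse inside the model.

Two differences from RandomColorJitter are worth noting. First, tf.image.random_brightness shifts brightness additively rather than multiplying by a factor. Second, each tf.image.random_* call draws one factor for the whole batch. The function name `tf_color_jitter` is hypothetical, and its defaults are copied from the RandomColorJitter arguments.

```python
import numpy as np
import tensorflow as tf


def tf_color_jitter(images, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1):
    """Hypothetical TF-only color jitter for float RGB images in [0, 1], shape (..., H, W, 3)."""

    def clip(t):
        # Keep values in the valid range after every step, like RandomColorJitter does
        return tf.clip_by_value(t, 0.0, 1.0)

    if brightness > 0:
        # Additive shift drawn from [-brightness, brightness]
        images = clip(tf.image.random_brightness(images, max_delta=brightness))
    if contrast > 0:
        images = clip(tf.image.random_contrast(images, 1 - contrast, 1 + contrast))
    if saturation > 0:
        images = clip(tf.image.random_saturation(images, 1 - saturation, 1 + saturation))
    if hue > 0:
        # max_delta must lie in [0, 0.5]
        images = clip(tf.image.random_hue(images, max_delta=hue))
    return images


a = np.random.default_rng(0).random((4, 50, 50, 3)).astype(np.float32)
b = np.ones((4, 1), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((a, b)).batch(3, drop_remainder=True)
# In the thread's setup this would be (tf_color_jitter(augmentation_layers(x, training=True)), y)
dataset = dataset.map(
    lambda x, y: (tf_color_jitter(x), y),
    num_parallel_calls=tf.data.AUTOTUNE,
)
xb, yb = next(iter(dataset))
```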

@edge7
Contributor

edge7 commented Mar 22, 2025

@edge7 I got a solution, could you please confirm this. For such cases, this can be used https://keras.io/api/layers/preprocessing_layers/image_augmentation/pipeline/

Instead of

augmentation_layers = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

should do

augmentation_layers = keras.layers.Pipeline(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

This looks like the official approach, so I'd go with it without hesitation.
I think this is the key:

When the layers in the pipeline are compatible with tf.data, the pipeline will also remain tf.data compatible. That is to say, the pipeline will not attempt to convert its inputs to backend-native tensors when in a tf.data context (unlike a Sequential model).

@pure-rgb
Author

Thanks. One last question: I was trying to use a custom layer, RandomColorJitter, built with keras.ops. Subclassing keras.layers.Layer didn't work, and subclassing TFDataLayer didn't work either. What is the right approach in such a case? And what exactly does this mean; what compatibility is it referring to?

When the layers in the pipeline are compatible with tf.data, the pipeline will also remain tf.data compatible.
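In practice, a layer counts as compatible if calling it on the symbolic tf.Tensors inside dataset.map() succeeds under the active backend. One rough way to check a given layer empirically is to try it in a tiny tf.data pipeline. `probe_tf_data` below is a hypothetical helper. A False result only says that this particular call failed, not why it failed.

```python
import numpy as np
import tensorflow as tf
import keras
from keras import layers


def probe_tf_data(layer, sample):
    """Hypothetical helper: True if `layer` runs inside a tf.data map() under the current backend."""
    dataset = tf.data.Dataset.from_tensors(sample)
    try:
        # map() traces the call into a TF graph; next() then executes it once
        dataset = dataset.map(lambda x: layer(x, training=True))
        next(iter(dataset))
        return True
    except Exception:
        return False


sample = np.ones((3, 50, 50, 3), dtype=np.float32)
print(keras.backend.backend(), probe_tf_data(layers.RandomFlip("horizontal"), sample))
```

Running the same probe on a keras.Sequential model or on the custom RandomColorJitter layer under the JAX or torch backend is a quick way to see which pieces break the pipeline.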
