diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index 25facbc50ce1..b9fd4312153a 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -28,11 +28,11 @@ def setup_package(): whl_path = re.findall( r"[^\s]*\.whl", build_process.stdout, - )[-1] + ) if not whl_path: print(build_process.stderr) raise ValueError("Installing Keras package unsuccessful. ") - return whl_path + return whl_path[-1] def create_virtualenv(): diff --git a/keras/src/backend/jax/image.py b/keras/src/backend/jax/image.py index c167c02787bb..0e82eab53933 100644 --- a/keras/src/backend/jax/image.py +++ b/keras/src/backend/jax/image.py @@ -465,7 +465,7 @@ def affine_transform( # transform the indices coordinates = jnp.einsum("Bhwij, Bjk -> Bhwik", indices, transform) coordinates = jnp.moveaxis(coordinates, source=-1, destination=1) - coordinates += jnp.reshape(a=offset, shape=(*offset.shape, 1, 1, 1)) + coordinates += jnp.reshape(offset, shape=(*offset.shape, 1, 1, 1)) # apply affine transformation _map_coordinates = functools.partial( diff --git a/keras/src/backend/torch/image.py b/keras/src/backend/torch/image.py index 3bf2f2f45d55..d49914559fc7 100644 --- a/keras/src/backend/torch/image.py +++ b/keras/src/backend/torch/image.py @@ -424,7 +424,7 @@ def affine_transform( # transform the indices coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform) coordinates = torch.moveaxis(coordinates, source=-1, destination=1) - coordinates += torch.reshape(a=offset, shape=(*offset.shape, 1, 1, 1)) + coordinates += torch.reshape(offset, shape=(*offset.shape, 1, 1, 1)) # Note: torch.stack is faster than torch.vmap when the batch size is small. affined = torch.stack( diff --git a/keras/src/export/onnx.py b/keras/src/export/onnx.py index 4acde7cef98f..391a15c64101 100644 --- a/keras/src/export/onnx.py +++ b/keras/src/export/onnx.py @@ -80,7 +80,7 @@ def export_onnx(model, filepath, verbose=None, input_signature=None, **kwargs): decorated_fn = get_concrete_fn(model, input_signature, **kwargs) # Use `tf2onnx` to convert the `decorated_fn` to the ONNX format. - patch_tf2onnx() # TODO: Remove this once `tf2onnx` supports numpy 2. + # patch_tf2onnx() # TODO: Remove this once `tf2onnx` supports numpy 2. tf2onnx.convert.from_function( decorated_fn, input_signature, output_path=filepath ) diff --git a/requirements-common.txt b/requirements-common.txt index 08d81c03f3d9..5396c2412d1f 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -10,13 +10,14 @@ absl-py requests h5py ml-dtypes -protobuf +protobuf==4.21.6 # Earlier versions break Tensorflow>=2.19 tensorboard-plugin-profile rich build optree pytest-cov packaging +tf2onnx>=1.16.1 # For Numpy 2 support # for tree_test.py dm_tree coverage!=7.6.5 # 7.6.5 breaks CI diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 1cc1a1b75985..4d5a9484d6e3 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,6 +1,5 @@ # Tensorflow cpu-only version (needed for testing). -tensorflow-cpu~=2.18.0 -tf2onnx +tensorflow-cpu # Torch cpu-only version (needed for testing). --extra-index-url https://download.pytorch.org/whl/cpu @@ -8,7 +7,7 @@ torch==2.6.0+cpu # Jax with cuda support. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -jax[cuda12]==0.4.28 +jax[cuda12] flax -r requirements-common.txt diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index dbbf7a3b5106..5d7595c1bfce 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,5 @@ # Tensorflow with cuda support. -tensorflow[and-cuda]~=2.18.0 -tf2onnx +tensorflow[and-cuda] # Torch cpu-only version (needed for testing). --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 7b8eb7434a0e..0189e82078e4 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,6 +1,5 @@ # Tensorflow cpu-only version (needed for testing). tensorflow-cpu~=2.18.0 -tf2onnx # Torch with cuda support. # - torch is pinned to a version that is compatible with torch-xla diff --git a/requirements.txt b/requirements.txt index 6c6afff6842a..ebdb655d8e4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ # Tensorflow. -tensorflow-cpu~=2.18.0;sys_platform != 'darwin' -tensorflow~=2.18.0;sys_platform == 'darwin' +tensorflow-cpu~=2.19.0;sys_platform != 'darwin' +tensorflow~=2.19.0;sys_platform == 'darwin' tf_keras -tf2onnx # Torch. --extra-index-url https://download.pytorch.org/whl/cpu @@ -10,9 +9,7 @@ torch==2.6.0+cpu torch-xla==2.6.0;sys_platform != 'darwin' # Jax. -# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. -# Note that we test against the latest JAX on GPU. -jax[cpu]==0.5.0 +jax[cpu] flax # Common deps.