diff --git a/tensorflow_probability/python/internal/tf_keras.py b/tensorflow_probability/python/internal/tf_keras.py index 5f1cdf4cff..61a5a2755b 100644 --- a/tensorflow_probability/python/internal/tf_keras.py +++ b/tensorflow_probability/python/internal/tf_keras.py @@ -20,8 +20,13 @@ # pylint: disable=g-import-not-at-top # pylint: disable=unused-import # pylint: disable=wildcard-import -_keras_version_fn = getattr(tf.keras, "version", None) -if _keras_version_fn and _keras_version_fn().startswith("3."): +try: + _keras_version_fn = getattr(tf.keras, "version", None) + _use_tf_keras = _keras_version_fn and _keras_version_fn().startswith("3.") + del _keras_version_fn +except ImportError: + _use_tf_keras = True +if _use_tf_keras: from tf_keras import * from tf_keras import __internal__ import tf_keras.api._v1.keras.__internal__.legacy.layers as tf1_layers @@ -35,4 +40,4 @@ del tf1 del tf -del _keras_version_fn +del _use_tf_keras