diff --git a/export/orbax/export/data_processors/tf_data_processor_test.py b/export/orbax/export/data_processors/tf_data_processor_test.py
index c21580f2b..3d274ad04 100644
--- a/export/orbax/export/data_processors/tf_data_processor_test.py
+++ b/export/orbax/export/data_processors/tf_data_processor_test.py
@@ -56,26 +56,23 @@ def test_output_signature_raises_error_without_calling_prepare(self):
       _ = processor.output_signature
 
   def test_prepare_fails_with_multiple_calls(self):
-    processor = tf_data_processor.TfDataProcessor(lambda x: x)
+    processor = tf_data_processor.TfDataProcessor(lambda x: x, name='identity')
     processor.prepare(
-        'add',
-        input_signature=(tf.TensorSpec([None, 3], tf.float32),),
+        (tf.TensorSpec([None, 3], tf.float32),),
     )
     with self.assertRaisesWithLiteralMatch(
         RuntimeError, '`prepare()` can only be called once.'
     ):
       processor.prepare(
-          'add',
-          input_signature=(tf.TensorSpec([None, 3], tf.float32),),
+          (tf.TensorSpec([None, 3], tf.float32),),
       )
 
   def test_prepare_succeeds(self):
     processor = tf_data_processor.TfDataProcessor(
-        tf.function(lambda x, y: x + y)
+        tf.function(lambda x, y: x + y), name='add'
     )
     processor.prepare(
-        'add',
-        input_signature=(
+        (
            tf.TensorSpec([None, 3], tf.float64),
            tf.TensorSpec([None, 3], tf.float64),
        ),
@@ -107,10 +104,11 @@ def test_prepare_polymorphic_function_with_default_input_signature(self):
     def preprocessor_callable(x, y):
       return x + y
 
-    processor = tf_data_processor.TfDataProcessor(preprocessor_callable)
+    processor = tf_data_processor.TfDataProcessor(
+        preprocessor_callable, name='add'
+    )
     processor.prepare(
-        'add',
-        input_signature=(
+        (
            tf.TensorSpec([None, 3], tf.float32),
            tf.TensorSpec([None, 3], tf.float32),
        ),
@@ -136,7 +134,8 @@ def test_suppress_x64_output(self):
     processor = tf_data_processor.TfDataProcessor(
         tf.function(
             lambda x, y: tf.cast(x, tf.float64) + tf.cast(y, tf.float64)
-        )
+        ),
+        name='add_f64',
     )
     input_signature = (
         tf.TensorSpec([None, 3], tf.float32),
@@ -144,17 +143,18 @@
     )
 
     # With suppress_x64_output=True, f64 output is suppressed to f32.
-    processor.prepare('add_f64', input_signature, suppress_x64_output=True)
+    processor.prepare(input_signature, suppress_x64_output=True)
     self.assertEqual(
         processor.output_signature,
         obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f32),
     )
 
   def test_convert_to_bfloat16(self):
-    processor = tf_data_processor.TfDataProcessor(lambda x: 0.5 + x)
+    processor = tf_data_processor.TfDataProcessor(
+        lambda x: 0.5 + x, name='preprocessor'
+    )
     processor.prepare(
-        'preprocessor',
-        input_signature=(tf.TensorSpec((), tf.float32)),
+        (tf.TensorSpec((), tf.float32)),
         bfloat16_options=converter_options_v2_pb2.ConverterOptionsV2(
             bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions(
                 scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
@@ -168,15 +168,16 @@
     )
 
   def test_bfloat16_convert_error(self):
-    processor = tf_data_processor.TfDataProcessor(lambda x: 0.5 + x)
+    processor = tf_data_processor.TfDataProcessor(
+        lambda x: 0.5 + x, name='preprocessor'
+    )
     with self.assertRaisesRegex(
         google_error.StatusNotOk,
         'Found bfloat16 ops in the model. The model may have been converted'
         ' before. It should not be converted again.',
     ):
       processor.prepare(
-          'preprocessor',
-          input_signature=(tf.TensorSpec((), tf.bfloat16)),
+          (tf.TensorSpec((), tf.bfloat16)),
          bfloat16_options=converter_options_v2_pb2.ConverterOptionsV2(
              bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions(
                  scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
@@ -185,12 +186,9 @@
     )
 
   def test_prepare_with_shlo_bf16_inputs(self):
-    processor = tf_data_processor.TfDataProcessor(lambda x: x)
+    processor = tf_data_processor.TfDataProcessor(lambda x: x, name='identity')
     processor.prepare(
-        'identity',
-        input_signature=(
-            obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.bf16),
-        ),
+        (obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.bf16),),
     )
     self.assertEqual(
         processor.concrete_function.structured_input_signature[0][0].dtype,
diff --git a/export/orbax/export/oex_orchestration.py b/export/orbax/export/oex_orchestration.py
index a26fa53e7..95ccde5c3 100644
--- a/export/orbax/export/oex_orchestration.py
+++ b/export/orbax/export/oex_orchestration.py
@@ -14,7 +14,11 @@
 
 """Pipeline: pre-processor + model-function + post-processor."""
 
+import dataclasses
 from typing import Any, Dict, List, Sequence, Tuple, TypeVar
 
 from absl import logging
 import jax
+import jaxtyping
+from orbax.export.data_processors import data_processor_base
+from orbax.export.modules import obm_module
diff --git a/export/orbax/export/oex_orchestration_test.py b/export/orbax/export/oex_orchestration_test.py
index 2e7de85e8..8a6adf4c5 100644
--- a/export/orbax/export/oex_orchestration_test.py
+++ b/export/orbax/export/oex_orchestration_test.py
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from unittest import mock
+
 from absl.testing import absltest
 from absl.testing import parameterized
 from orbax.export import oex_orchestration
+from orbax.export.data_processors import data_processor_base
+from orbax.export.modules import obm_module
 import tensorflow as tf
 
 
@@ -23,6 +27,12 @@ def tf_fn(a):
   return a
 
 
+class TestProcessor(data_processor_base.DataProcessor):
+
+  def prepare(self, input_signature):
+    pass
+
+
 class OexOrchestrationTest(parameterized.TestCase):
   # Dummy test to make copybara happy, will be removed once all the obm
   # dependencies are OSSed.
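
The updated tests exercise a changed TfDataProcessor API: the processor name is now passed to the constructor, and prepare() takes the input signature as its first positional argument instead of a name plus an input_signature= keyword. A minimal usage sketch, assuming only the constructor and prepare() signatures shown in these tests (the 'add' name and tensor shapes are illustrative, not part of the diff):

    import tensorflow as tf
    from orbax.export.data_processors import tf_data_processor

    # The name moves from prepare() to the constructor.
    processor = tf_data_processor.TfDataProcessor(
        tf.function(lambda x, y: x + y), name='add'
    )
    # prepare() now takes the input signature positionally.
    processor.prepare(
        (
            tf.TensorSpec([None, 3], tf.float32),
            tf.TensorSpec([None, 3], tf.float32),
        ),
    )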