Skip to content

Commit

Permalink
Revert "Change references of tf.__internal__.saved_model.StructureCod…
Browse files Browse the repository at this point in the history
…er to their replacement functions"

This reverts commit 1020caa for compatibility of TFP 0.15 with TensorFlow 2.7.0.
  • Loading branch information
emilyfertig committed Nov 16, 2021
1 parent 76eaa7e commit 286af5b
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 10 deletions.
5 changes: 3 additions & 2 deletions tensorflow_probability/python/bijectors/blockwise_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,9 @@ def call_forward(bij, x):
self.assertAllClose(call_forward(unflat, x), blockwise.forward(x))

# Type spec can be encoded/decoded.
enc = tf.__internal__.saved_model.encode_structure(blockwise._type_spec)
dec = tf.__internal__.saved_model.decode_proto(enc)
struct_coder = tf.__internal__.saved_model.StructureCoder()
enc = struct_coder.encode_structure(blockwise._type_spec)
dec = struct_coder.decode_proto(enc)
self.assertEqual(blockwise._type_spec, dec)

def testNonCompositeTensor(self):
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_probability/python/bijectors/chain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,9 @@ def call_forward(bij, x):
self.assertAllClose(call_forward(unflat, x), chain.forward(x))

# TypeSpec can be encoded/decoded.
enc = tf.__internal__.saved_model.encode_structure(chain._type_spec)
dec = tf.__internal__.saved_model.decode_proto(enc)
struct_coder = tf.__internal__.saved_model.StructureCoder()
enc = struct_coder.encode_structure(chain._type_spec)
dec = struct_coder.decode_proto(enc)
self.assertEqual(chain._type_spec, dec)

def testNonCompositeTensor(self):
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_probability/python/bijectors/joint_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ def call_forward(bij, x):
self.assertAllCloseNested(call_forward(unflat, x), bij.forward(x))

# Type spec can be encoded/decoded.
enc = tf.__internal__.saved_model.encode_structure(bij._type_spec)
dec = tf.__internal__.saved_model.decode_proto(enc)
struct_coder = tf.__internal__.saved_model.StructureCoder()
enc = struct_coder.encode_structure(bij._type_spec)
dec = struct_coder.decode_proto(enc)
self.assertEqual(bij._type_spec, dec)

def testNonCompositeTensor(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ def composite_helper(v):
mk_err_msg('(Unable to convert dependent entry \'{}\' of object '
'\'{}\': {})'.format(k, obj, str(e))))
result = cls(**kwargs)
struct_coder = nested_structure_coder.StructureCoder()
try:
nested_structure_coder.encode_structure(result._type_spec) # pylint: disable=protected-access
struct_coder.encode_structure(result._type_spec) # pylint: disable=protected-access
except nested_structure_coder.NotEncodableError as e:
raise NotImplementedError(
mk_err_msg('(Unable to serialize: {})'.format(str(e))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,9 @@ def test_save_restore_functor(self):
a = tf.constant([3., 2.])
ct = ThingWithCallableArg(a, f=f)

struct_coder = tf.__internal__.saved_model.StructureCoder()
with self.assertRaisesRegex(ValueError, 'Cannot serialize'):
tf.__internal__.saved_model.encode_structure(ct._type_spec) # pylint: disable=protected-access
struct_coder.encode_structure(ct._type_spec) # pylint: disable=protected-access

@tfp.experimental.auto_composite_tensor(module_name='my.module')
class F(tfp.experimental.AutoCompositeTensor):
Expand All @@ -372,8 +373,8 @@ def __call__(self, *args, **kwargs):
return f(*args, **kwargs)

ct_functor = ThingWithCallableArg(a, f=F())
enc = tf.__internal__.saved_model.encode_structure(ct_functor._type_spec)
dec = tf.__internal__.saved_model.decode_proto(enc)
enc = struct_coder.encode_structure(ct_functor._type_spec)
dec = struct_coder.decode_proto(enc)
self.assertEqual(dec, ct_functor._type_spec)

def test_composite_tensor_callable_arg(self):
Expand Down

0 comments on commit 286af5b

Please sign in to comment.