diff --git a/tensorflow_probability/python/internal/backend/numpy/misc.py b/tensorflow_probability/python/internal/backend/numpy/misc.py index c6977864d0..d9f1ecdd5c 100644 --- a/tensorflow_probability/python/internal/backend/numpy/misc.py +++ b/tensorflow_probability/python/internal/backend/numpy/misc.py @@ -118,11 +118,13 @@ def _sort(values, axis=-1, direction='ASCENDING', name=None): # pylint: disable # TODO(b/140685491): Add unit-test. -def _tensor_scatter_nd_add(tensor, indices, updates, name=None): # pylint: disable=unused-argument +def _tensor_scatter_nd_add( + tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument """Numpy implementation of `tf.tensor_scatter_nd_add`.""" indices = _convert_to_tensor(indices) tensor = _convert_to_tensor(tensor) updates = _convert_to_tensor(updates) + del bad_indices_policy indices = tuple( indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491) if JAX_MODE: @@ -132,11 +134,13 @@ def _tensor_scatter_nd_add(tensor, indices, updates, name=None): # pylint: disa # TODO(b/140685491): Add unit-test. -def _tensor_scatter_nd_sub(tensor, indices, updates, name=None): # pylint: disable=unused-argument +def _tensor_scatter_nd_sub( + tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument """Numpy implementation of `tf.tensor_scatter_nd_sub`.""" indices = _convert_to_tensor(indices) tensor = _convert_to_tensor(tensor) updates = _convert_to_tensor(updates) + del bad_indices_policy indices = tuple( indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491) if JAX_MODE: @@ -146,11 +150,13 @@ def _tensor_scatter_nd_sub(tensor, indices, updates, name=None): # pylint: disa # TODO(b/140685491): Add unit-test. -def _tensor_scatter_nd_update(tensor, indices, updates, name=None): # pylint: disable=unused-argument +def _tensor_scatter_nd_update( + tensor, indices, updates, bad_indices_policy='', name=None): # pylint: disable=unused-argument """Numpy implementation of `tf.tensor_scatter_nd_update`.""" indices = _convert_to_tensor(indices) tensor = _convert_to_tensor(tensor) updates = _convert_to_tensor(updates) + del bad_indices_policy indices = tuple( indices[..., i] for i in range(indices.shape[-1])) # TODO(b/140685491) if JAX_MODE: