Skip to content

Commit

Permalink
Remove alias_inplace_update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684493322
  • Loading branch information
lingvo-bot authored and copybara-github committed Oct 10, 2024
1 parent a396c7f commit 4d51394
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 30 deletions.
3 changes: 2 additions & 1 deletion lingvo/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ lingvo_cuda_py_test(
],
)

pytype_library(
pytype_strict_library(
name = "attention",
srcs = ["attention.py"],
deps = [
Expand All @@ -62,6 +62,7 @@ lingvo_cuda_py_test(
shard_count = 5,
deps = [
":attention",
":layers",
":py_utils",
":quant_utils",
":test_utils",
Expand Down
16 changes: 3 additions & 13 deletions lingvo/core/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from lingvo.core import symbolic
import numpy as np

from tensorflow.python.ops import inplace_ops # pylint:disable=g-direct-tensorflow-import


# Currently, quantization statistics cannot be accumulated across arbitrary
# defuns, so we allow them to be disabled. A potentially more robust fix is
Expand Down Expand Up @@ -1889,17 +1887,9 @@ def ExtendSourcePacked(
# This could happen in cases where function is called by recurrent.py
# (for example target_sequence_sampler.)
t = tf.reshape(t, [])
# TODO(b/227528061): `alias_inplace_update` is deprecated and has
# non-deterministic results when running on CPU/GPU. Consider
# replacing it with e.g. `tf.tensor_scatter_nd_update`
if py_utils.ReplaceAliasInplaceUpdateInAttention():
extended_packed_src[key] = tf.tensor_scatter_nd_update(
cached_packed_src[key], [[t]], [processed]
)
else:
extended_packed_src[key] = inplace_ops.alias_inplace_update(
cached_packed_src[key], t, processed
)
extended_packed_src[key] = tf.tensor_scatter_nd_update(
cached_packed_src[key], [[t]], [processed]
)
else:
processed = tf.reshape(processed_packed_src[key], [1, batch_size, -1])
extended_packed_src[key] = tf.concat(
Expand Down
16 changes: 0 additions & 16 deletions lingvo/core/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,22 +726,6 @@ def SetProcessFPropResultsInEager(process_fprop_results=True):
_RUN_PROCESS_FPROP_IN_EAGER = process_fprop_results


# TODO(b/227528061): `alias_inplace_update` is deprecated and has
# non-deterministic results when running on CPU/GPU. We have observed incorrect
# results when running the op in TF2 on CPU.
# Used for tests only.
_REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION = False


def SetReplaceAliasInplaceUpdateInAttention(replace_op=True):
global _REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION
_REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION = replace_op


def ReplaceAliasInplaceUpdateInAttention():
return _REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION


# Defaults to True.
# Set to false for already existing Eager mode tests.
_EAGER_RNG_ADAPTATION = True
Expand Down

0 comments on commit 4d51394

Please sign in to comment.