Skip to content

Commit 4d51394

Browse files
lingvo-botcopybara-github
authored andcommitted
Remove alias_inplace_update
PiperOrigin-RevId: 684493322
1 parent a396c7f commit 4d51394

File tree

3 files changed

+5
-30
lines changed

3 files changed

+5
-30
lines changed

lingvo/core/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ lingvo_cuda_py_test(
4141
],
4242
)
4343

44-
pytype_library(
44+
pytype_strict_library(
4545
name = "attention",
4646
srcs = ["attention.py"],
4747
deps = [
@@ -62,6 +62,7 @@ lingvo_cuda_py_test(
6262
shard_count = 5,
6363
deps = [
6464
":attention",
65+
":layers",
6566
":py_utils",
6667
":quant_utils",
6768
":test_utils",

lingvo/core/attention.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from lingvo.core import symbolic
2727
import numpy as np
2828

29-
from tensorflow.python.ops import inplace_ops # pylint:disable=g-direct-tensorflow-import
30-
3129

3230
# Currently, quantization statistics cannot be accumulated across arbitrary
3331
# defuns, so we allow them to be disabled. A potentially more robust fix is
@@ -1889,17 +1887,9 @@ def ExtendSourcePacked(
18891887
# This could happen in cases where function is called by recurrent.py
18901888
# (for example target_sequence_sampler.)
18911889
t = tf.reshape(t, [])
1892-
# TODO(b/227528061): `alias_inplace_update` is deprecated and has
1893-
# non-deterministic results when running on CPU/GPU. Consider
1894-
# replacing it with e.g. `tf.tensor_scatter_nd_update`
1895-
if py_utils.ReplaceAliasInplaceUpdateInAttention():
1896-
extended_packed_src[key] = tf.tensor_scatter_nd_update(
1897-
cached_packed_src[key], [[t]], [processed]
1898-
)
1899-
else:
1900-
extended_packed_src[key] = inplace_ops.alias_inplace_update(
1901-
cached_packed_src[key], t, processed
1902-
)
1890+
extended_packed_src[key] = tf.tensor_scatter_nd_update(
1891+
cached_packed_src[key], [[t]], [processed]
1892+
)
19031893
else:
19041894
processed = tf.reshape(processed_packed_src[key], [1, batch_size, -1])
19051895
extended_packed_src[key] = tf.concat(

lingvo/core/py_utils.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -726,22 +726,6 @@ def SetProcessFPropResultsInEager(process_fprop_results=True):
726726
_RUN_PROCESS_FPROP_IN_EAGER = process_fprop_results
727727

728728

729-
# TODO(b/227528061): `alias_inplace_update` is deprecated and has
730-
# non-deterministic results when running on CPU/GPU. We have observed incorrect
731-
# results when running the op in TF2 on CPU.
732-
# Used for tests only.
733-
_REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION = False
734-
735-
736-
def SetReplaceAliasInplaceUpdateInAttention(replace_op=True):
737-
global _REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION
738-
_REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION = replace_op
739-
740-
741-
def ReplaceAliasInplaceUpdateInAttention():
742-
return _REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION
743-
744-
745729
# Defaults to True.
746730
# Set to false for already existing Eager mode tests.
747731
_EAGER_RNG_ADAPTATION = True

0 commit comments

Comments
 (0)