File tree Expand file tree Collapse file tree 3 files changed +5
-30
lines changed Expand file tree Collapse file tree 3 files changed +5
-30
lines changed Original file line number Diff line number Diff line change @@ -41,7 +41,7 @@ lingvo_cuda_py_test(
41
41
],
42
42
)
43
43
44
- pytype_library (
44
+ pytype_strict_library (
45
45
name = "attention" ,
46
46
srcs = ["attention.py" ],
47
47
deps = [
@@ -62,6 +62,7 @@ lingvo_cuda_py_test(
62
62
shard_count = 5 ,
63
63
deps = [
64
64
":attention" ,
65
+ ":layers" ,
65
66
":py_utils" ,
66
67
":quant_utils" ,
67
68
":test_utils" ,
Original file line number Diff line number Diff line change 26
26
from lingvo .core import symbolic
27
27
import numpy as np
28
28
29
- from tensorflow .python .ops import inplace_ops # pylint:disable=g-direct-tensorflow-import
30
-
31
29
32
30
# Currently, quantization statistics cannot be accumulated across arbitrary
33
31
# defuns, so we allow them to be disabled. A potentially more robust fix is
@@ -1889,17 +1887,9 @@ def ExtendSourcePacked(
1889
1887
# This could happen in cases where function is called by recurrent.py
1890
1888
# (for example target_sequence_sampler.)
1891
1889
t = tf .reshape (t , [])
1892
- # TODO(b/227528061): `alias_inplace_update` is deprecated and has
1893
- # non-deterministic results when running on CPU/GPU. Consider
1894
- # replacing it with e.g. `tf.tensor_scatter_nd_update`
1895
- if py_utils .ReplaceAliasInplaceUpdateInAttention ():
1896
- extended_packed_src [key ] = tf .tensor_scatter_nd_update (
1897
- cached_packed_src [key ], [[t ]], [processed ]
1898
- )
1899
- else :
1900
- extended_packed_src [key ] = inplace_ops .alias_inplace_update (
1901
- cached_packed_src [key ], t , processed
1902
- )
1890
+ extended_packed_src [key ] = tf .tensor_scatter_nd_update (
1891
+ cached_packed_src [key ], [[t ]], [processed ]
1892
+ )
1903
1893
else :
1904
1894
processed = tf .reshape (processed_packed_src [key ], [1 , batch_size , - 1 ])
1905
1895
extended_packed_src [key ] = tf .concat (
Original file line number Diff line number Diff line change @@ -726,22 +726,6 @@ def SetProcessFPropResultsInEager(process_fprop_results=True):
726
726
_RUN_PROCESS_FPROP_IN_EAGER = process_fprop_results
727
727
728
728
729
- # TODO(b/227528061): `alias_inplace_update` is deprecated and has
730
- # non-deterministic results when running on CPU/GPU. We have observed incorrect
731
- # results when running the op in TF2 on CPU.
732
- # Used for tests only.
733
- _REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION = False
734
-
735
-
736
- def SetReplaceAliasInplaceUpdateInAttention (replace_op = True ):
737
- global _REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION
738
- _REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION = replace_op
739
-
740
-
741
- def ReplaceAliasInplaceUpdateInAttention ():
742
- return _REPLACE_ALIAS_INPLACE_UPDATE_IN_ATTENTION
743
-
744
-
745
729
# Defaults to True.
746
730
# Set to false for already existing Eager mode tests.
747
731
_EAGER_RNG_ADAPTATION = True
You can’t perform that action at this time.
0 commit comments