Skip to content

Commit 441f63f

Browse files
committed
lint changes in few files
1 parent 5401505 commit 441f63f

File tree

8 files changed

+12
-12
lines changed

8 files changed

+12
-12
lines changed

official/nlp/modeling/layers/transformer_encoder_block.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class RMSNorm(tf_keras.layers.Layer):
2828

2929
def __init__(
3030
self,
31-
axis: Union[int , Sequence[int]] = -1,
31+
axis: Union[int, Sequence[int]] = -1,
3232
epsilon: float = 1e-6,
3333
**kwargs
3434
):
@@ -43,7 +43,8 @@ def __init__(
4343
self.axis = [axis] if isinstance(axis, int) else axis
4444
self.epsilon = epsilon
4545

46-
def build(self, input_shape: Union[tf.TensorShape, Sequence[Union[int, None]]]):
46+
def build(self,
47+
input_shape: Union[tf.TensorShape, Sequence[Union[int, None]]]):
4748
input_shape = tf.TensorShape(input_shape)
4849
scale_shape = [1] * input_shape.rank
4950
for dim in self.axis:

official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
"""Defines base abstract uplift network layers."""
1616

1717
import abc
18+
from typing import Union
1819

1920
import tensorflow as tf, tf_keras
2021

2122
from official.recommendation.uplift import types
2223

23-
from typing import Union
24-
2524

2625
class BaseTwoTowerUpliftNetwork(tf_keras.layers.Layer, metaclass=abc.ABCMeta):
2726
"""Abstract class for uplift layers that compute control and treatment logits.

official/recommendation/uplift/metrics/label_mean.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
"""Keras metric for computing the label mean sliced by treatment group."""
1616

17+
from typing import Union
18+
1719
import tensorflow as tf, tf_keras
1820

1921
from official.recommendation.uplift import types
2022
from official.recommendation.uplift.metrics import treatment_sliced_metric
2123

22-
from typing import Union
2324

2425
@tf_keras.utils.register_keras_serializable(package="Uplift")
2526
class LabelMean(tf_keras.metrics.Metric):

official/recommendation/uplift/metrics/label_variance.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# limitations under the License.
1414

1515
"""Keras metric for computing the label variance sliced by treatment group."""
16+
from typing import Union
1617

1718
import tensorflow as tf, tf_keras
1819

1920
from official.recommendation.uplift import types
2021
from official.recommendation.uplift.metrics import treatment_sliced_metric
2122
from official.recommendation.uplift.metrics import variance
2223

23-
from typing import Union
2424

2525
@tf_keras.utils.register_keras_serializable(package="Uplift")
2626
class LabelVariance(tf_keras.metrics.Metric):

official/recommendation/uplift/metrics/metric_configs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class SlicedMetricConfig(base_config.Config):
3535

3636
slicing_feature: Union[str, None] = None
3737
slicing_spec: Union[Mapping[str, int], None] = None
38-
slicing_feature_dtype: Union[str, None ]= None
38+
slicing_feature_dtype: Union[str, None] = None
3939

4040
def __post_init__(
4141
self, default_params: dict[str, Any], restrictions: list[str]

official/recommendation/uplift/metrics/sliced_metric.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
"""Keras metric for reporting metrics sliced by a feature."""
1616

1717
import copy
18+
from typing import Union
1819

1920
import tensorflow as tf, tf_keras
2021

21-
from typing import Union
22-
2322

2423
class SlicedMetric(tf_keras.metrics.Metric):
2524
"""A metric sliced by integer, boolean, or string features.

official/recommendation/uplift/metrics/uplift_mean.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
"""Keras metric for computing the mean uplift sliced by treatment group."""
1616

17+
from typing import Union
18+
1719
import tensorflow as tf, tf_keras
1820

1921
from official.recommendation.uplift import types
2022
from official.recommendation.uplift.metrics import treatment_sliced_metric
2123

22-
from typing import Union
2324

2425
@tf_keras.utils.register_keras_serializable(package="Uplift")
2526
class UpliftMean(tf_keras.metrics.Metric):

official/recommendation/uplift/types.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
# limitations under the License.
1414

1515
"""Defines types used by the keras uplift modeling library."""
16+
from typing import Union
1617

1718
import tensorflow as tf, tf_keras
18-
from typing import Union
1919

2020
TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
21-
2221
ListOfTensors = list[TensorType]
2322
TupleOfTensors = tuple[TensorType, ...]
2423
DictOfTensors = dict[str, TensorType]

0 commit comments

Comments
 (0)