Skip to content

Commit 59aecc4

Browse files
Fix Ruff rule B008
1 parent 9724828 commit 59aecc4

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

tfx/dsl/component/experimental/decorators_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Tests for tfx.dsl.components.base.decorators."""
1515

1616

17+
from __future__ import annotations
1718
import pytest
1819
import os
1920
from typing import Any, Dict, List, Optional
@@ -141,8 +142,10 @@ def verify_beam_pipeline_arg(a: int) -> OutputDict(b=float): # pytype: disable=
141142

142143
def verify_beam_pipeline_arg_non_none_default_value(
143144
a: int,
144-
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
145+
beam_pipeline: BeamComponentParameter[beam.Pipeline] = None,
145146
) -> OutputDict(b=float): # pytype: disable=invalid-annotation,wrong-arg-types
147+
if beam_pipeline is None:
148+
beam_pipeline = beam.Pipeline()
146149
del beam_pipeline
147150
return {'b': float(a)}
148151

tfx/dsl/component/experimental/decorators_typeddict_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Tests for tfx.dsl.components.base.decorators."""
1515

1616

17+
from __future__ import annotations
1718
import pytest
1819
import os
1920
from typing import Any, Dict, List, Optional, TypedDict
@@ -141,8 +142,10 @@ def verify_beam_pipeline_arg(a: int) -> TypedDict('Output6', dict(b=float)): #
141142

142143
def verify_beam_pipeline_arg_non_none_default_value(
143144
a: int,
144-
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
145+
beam_pipeline: BeamComponentParameter[beam.Pipeline] | None = None,
145146
) -> TypedDict('Output7', dict(b=float)): # pytype: disable=wrong-arg-types
147+
if beam_pipeline is None:
148+
beam_pipeline = beam.Pipeline()
146149
del beam_pipeline
147150
return {'b': float(a)}
148151

tfx/examples/bert/utils/bert_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Configurable fine-tuning BERT models for various tasks."""
1515

16+
from __future__ import annotations
1617
from typing import Optional, List, Union
1718

1819
import tensorflow as tf
@@ -59,8 +60,7 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer,
5960

6061
def compile_bert_classifier(
6162
model: tf.keras.Model,
62-
loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy(
63-
from_logits=True),
63+
loss: tf.keras.losses.Loss | None = None,
6464
learning_rate: float = 2e-5,
6565
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None):
6666
"""Compile the BERT classifier using suggested parameters.
@@ -79,6 +79,9 @@ def compile_bert_classifier(
7979
Returns:
8080
None.
8181
"""
82+
if loss is None:
83+
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
84+
8285
if metrics is None:
8386
metrics = ["sparse_categorical_accuracy"]
8487

0 commit comments

Comments
 (0)