Commit 88b7328

more update
1 parent 0730752 commit 88b7328

File tree: 73 files changed, +481 −482 lines


examples/legacy/seq2seq/finetune_trainer.py

Lines changed: 3 additions & 3 deletions
@@ -231,9 +231,9 @@ def main():
 
     # set decoder_start_token_id for MBart
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
-        assert data_args.tgt_lang is not None and data_args.src_lang is not None, (
-            "mBart requires --tgt_lang and --src_lang"
-        )
+        assert (
+            data_args.tgt_lang is not None and data_args.src_lang is not None
+        ), "mBart requires --tgt_lang and --src_lang"
         if isinstance(tokenizer, MBartTokenizer):
             model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
         else:

examples/legacy/seq2seq/utils.py

Lines changed: 3 additions & 3 deletions
@@ -283,9 +283,9 @@ def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=N
         self.tokenizer = tokenizer
         self.pad_token_id = tokenizer.pad_token_id
         self.decoder_start_token_id = decoder_start_token_id
-        assert self.pad_token_id is not None, (
-            f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
-        )
+        assert (
+            self.pad_token_id is not None
+        ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
         self.data_args = data_args
         self.tpu_num_cores = tpu_num_cores
         self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}

examples/pytorch/summarization/run_summarization.py

Lines changed: 3 additions & 3 deletions
@@ -504,9 +504,9 @@ def main():
         return
 
     if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
-        assert data_args.lang is not None, (
-            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
-        )
+        assert (
+            data_args.lang is not None
+        ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
 
         tokenizer.src_lang = data_args.lang
         tokenizer.tgt_lang = data_args.lang

examples/pytorch/text-classification/run_classification.py

Lines changed: 6 additions & 6 deletions
@@ -198,9 +198,9 @@ def __post_init__(self):
             train_extension = self.train_file.split(".")[-1]
             assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
             validation_extension = self.validation_file.split(".")[-1]
-            assert validation_extension == train_extension, (
-                "`validation_file` should have the same extension (csv or json) as `train_file`."
-            )
+            assert (
+                validation_extension == train_extension
+            ), "`validation_file` should have the same extension (csv or json) as `train_file`."
 
 
 @dataclass
@@ -356,9 +356,9 @@ def main():
         if data_args.test_file is not None:
             train_extension = data_args.train_file.split(".")[-1]
             test_extension = data_args.test_file.split(".")[-1]
-            assert test_extension == train_extension, (
-                "`test_file` should have the same extension (csv or json) as `train_file`."
-            )
+            assert (
+                test_extension == train_extension
+            ), "`test_file` should have the same extension (csv or json) as `train_file`."
             data_files["test"] = data_args.test_file
         else:
             raise ValueError("Need either a dataset name or a test file for `do_predict`.")

examples/pytorch/text-classification/run_glue.py

Lines changed: 6 additions & 6 deletions
@@ -155,9 +155,9 @@ def __post_init__(self):
             train_extension = self.train_file.split(".")[-1]
             assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
             validation_extension = self.validation_file.split(".")[-1]
-            assert validation_extension == train_extension, (
-                "`validation_file` should have the same extension (csv or json) as `train_file`."
-            )
+            assert (
+                validation_extension == train_extension
+            ), "`validation_file` should have the same extension (csv or json) as `train_file`."
 
 
 @dataclass
@@ -312,9 +312,9 @@ def main():
         if data_args.test_file is not None:
             train_extension = data_args.train_file.split(".")[-1]
             test_extension = data_args.test_file.split(".")[-1]
-            assert test_extension == train_extension, (
-                "`test_file` should have the same extension (csv or json) as `train_file`."
-            )
+            assert (
+                test_extension == train_extension
+            ), "`test_file` should have the same extension (csv or json) as `train_file`."
             data_files["test"] = data_args.test_file
         else:
             raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

examples/pytorch/translation/run_translation_no_trainer.py

Lines changed: 3 additions & 3 deletions
@@ -435,9 +435,9 @@ def main():
 
     # Set decoder_start_token_id
    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
-        assert args.target_lang is not None and args.source_lang is not None, (
-            "mBart requires --target_lang and --source_lang"
-        )
+        assert (
+            args.target_lang is not None and args.source_lang is not None
+        ), "mBart requires --target_lang and --source_lang"
         if isinstance(tokenizer, MBartTokenizer):
             model.config.decoder_start_token_id = tokenizer.lang_code_to_id[args.target_lang]
         else:

examples/tensorflow/translation/run_translation.py

Lines changed: 3 additions & 3 deletions
@@ -500,9 +500,9 @@ def preprocess_function(examples):
 
     # region Set decoder_start_token_id
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
-        assert data_args.target_lang is not None and data_args.source_lang is not None, (
-            "mBart requires --target_lang and --source_lang"
-        )
+        assert (
+            data_args.target_lang is not None and data_args.source_lang is not None
+        ), "mBart requires --target_lang and --source_lang"
         if isinstance(tokenizer, MBartTokenizer):
             model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
         else:

src/transformers/cache_utils.py

Lines changed: 10 additions & 10 deletions
@@ -1673,7 +1673,7 @@ def __init__(
                 "config and it's not set to None."
             )
         self.config = config
-        self.device = device
+        self.device = device
         self.layer_device_map = layer_device_map
         self.max_cache_len = max_cache_len
         self.max_batch_size = max_batch_size
@@ -1816,13 +1816,13 @@ def reset(self):
 def _get_flat_dict_for_hybrid_cache(hybrid_cache: HybridCache):
     return {
         "config": getattr(hybrid_cache, "config"),
-        "device": str(getattr(hybrid_cache, "device")) if getattr(hybrid_cache, "device", None) != None else None,
+        "device": str(getattr(hybrid_cache, "device")) if getattr(hybrid_cache, "device", None) is not None else None,
         "layer_device_map": getattr(hybrid_cache, "layer_device_map"),
         "key_cache": getattr(hybrid_cache, "key_cache"),
         "value_cache": getattr(hybrid_cache, "value_cache"),
         "max_batch_size": getattr(hybrid_cache, "max_batch_size"),
         "max_cache_len": getattr(hybrid_cache, "max_cache_len"),
-        "_dtype": str(getattr(hybrid_cache, "_dtype")) if getattr(hybrid_cache, "_dtype", None) != None else None,
+        "_dtype": str(getattr(hybrid_cache, "_dtype")) if getattr(hybrid_cache, "_dtype", None) is not None else None,
     }
 
 
@@ -1833,9 +1833,9 @@ def _flatten_hybrid_cache(
     if not isinstance(hybrid_cache, HybridCache):
        raise RuntimeError("This pytree flattening function should only be applied to HybridCache")
 
-    if not is_torch_greater_or_equal_than_2_6:
+    if not is_torch_greater_or_equal_than_2_7:
         logger.warning_once(
-            "HybridCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
+            "HybridCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
         )
 
     return torch.utils._pytree._dict_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))
@@ -1851,11 +1851,11 @@ def _unflatten_hybrid_cache(
 ):
     dictionary = torch.utils._pytree._dict_unflatten(values, context)
     hybrid_cache = HybridCache(
-        dictionary["config"],
-        dictionary["max_batch_size"],
-        dictionary["max_cache_len"],
-        torch.device(dictionary["device"]) if dictionary["device"] != None else None,
-        getattr(torch, dictionary["_dtype"][len("torch."):]) if dictionary["_dtype"] != None else None,
+        dictionary["config"],
+        dictionary["max_batch_size"],
+        dictionary["max_cache_len"],
+        torch.device(dictionary["device"]) if dictionary["device"] is not None else None,
+        getattr(torch, dictionary["_dtype"][len("torch.") :]) if dictionary["_dtype"] is not None else None,
         dictionary["layer_device_map"],
     )
 

src/transformers/integrations/tpu.py

Lines changed: 3 additions & 3 deletions
@@ -21,9 +21,9 @@ def tpu_spmd_dataloader(dataloader: DataLoader):
     if is_torch_xla_available():
         import torch_xla.distributed.parallel_loader as pl
 
-        assert isinstance(dataloader, pl.MpDeviceLoader), (
-            "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."
-        )
+        assert isinstance(
+            dataloader, pl.MpDeviceLoader
+        ), "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."
 
         # This is to support PyTorch/XLA FSDP via SPMD.
         # Here we shard the input data's 0th dim across the fsdp axis.

src/transformers/modeling_utils.py

Lines changed: 12 additions & 12 deletions
@@ -2542,9 +2542,9 @@ def tie_encoder_to_decoder_recursively(
             total_decoder_name="",
             total_encoder_name="",
         ):
-            assert isinstance(decoder_pointer, nn.Module) and isinstance(encoder_pointer, nn.Module), (
-                f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
-            )
+            assert isinstance(decoder_pointer, nn.Module) and isinstance(
+                encoder_pointer, nn.Module
+            ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
             if hasattr(decoder_pointer, "weight"):
                 assert hasattr(encoder_pointer, "weight")
                 encoder_pointer.weight = decoder_pointer.weight
@@ -2558,9 +2558,9 @@ def tie_encoder_to_decoder_recursively(
             encoder_modules = encoder_pointer._modules
             decoder_modules = decoder_pointer._modules
             if len(decoder_modules) > 0:
-                assert len(encoder_modules) > 0, (
-                    f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
-                )
+                assert (
+                    len(encoder_modules) > 0
+                ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
 
                 all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
                 encoder_layer_pos = 0
@@ -5464,9 +5464,9 @@ def forward(
         Returns:
             `torch.FloatTensor`: The end logits for SQuAD.
         """
-        assert start_states is not None or start_positions is not None, (
-            "One of start_states, start_positions should be not None"
-        )
+        assert (
+            start_states is not None or start_positions is not None
+        ), "One of start_states, start_positions should be not None"
         if start_positions is not None:
             slen, hsz = hidden_states.shape[-2:]
             start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
@@ -5536,9 +5536,9 @@ def forward(
         """
         # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
         hsz = hidden_states.shape[-1]
-        assert start_states is not None or start_positions is not None, (
-            "One of start_states, start_positions should be not None"
-        )
+        assert (
+            start_states is not None or start_positions is not None
+        ), "One of start_states, start_positions should be not None"
         if start_positions is not None:
             start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
             start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

src/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py

Lines changed: 9 additions & 9 deletions
@@ -127,27 +127,27 @@ def convert_data2vec_checkpoint_to_pytorch(
 
         # self-attention output
         self_output: BertSelfOutput = layer.attention.output
-        assert self_output.dense.weight.shape == data2vec_layer.self_attn.out_proj.weight.shape, (
-            f"Shape for self_output.dense.weight should be {data2vec_layer.self_attn.out_proj.weight.shape}"
-        )
+        assert (
+            self_output.dense.weight.shape == data2vec_layer.self_attn.out_proj.weight.shape
+        ), f"Shape for self_output.dense.weight should be {data2vec_layer.self_attn.out_proj.weight.shape}"
         self_output.dense.weight = data2vec_layer.self_attn.out_proj.weight
         self_output.dense.bias = data2vec_layer.self_attn.out_proj.bias
         self_output.LayerNorm.weight = data2vec_layer.self_attn_layer_norm.weight
         self_output.LayerNorm.bias = data2vec_layer.self_attn_layer_norm.bias
 
         # intermediate
         intermediate: BertIntermediate = layer.intermediate
-        assert intermediate.dense.weight.shape == data2vec_layer.fc1.weight.shape, (
-            f"Shape for intermediate.dense.weight should be {data2vec_layer.fc1.weight.shape}"
-        )
+        assert (
+            intermediate.dense.weight.shape == data2vec_layer.fc1.weight.shape
+        ), f"Shape for intermediate.dense.weight should be {data2vec_layer.fc1.weight.shape}"
         intermediate.dense.weight = data2vec_layer.fc1.weight
         intermediate.dense.bias = data2vec_layer.fc1.bias
 
         # output
         bert_output: BertOutput = layer.output
-        assert bert_output.dense.weight.shape == data2vec_layer.fc2.weight.shape, (
-            f"Shape for bert_output.dense.weight should be {data2vec_layer.fc2.weight.shape}"
-        )
+        assert (
+            bert_output.dense.weight.shape == data2vec_layer.fc2.weight.shape
+        ), f"Shape for bert_output.dense.weight should be {data2vec_layer.fc2.weight.shape}"
         bert_output.dense.weight = data2vec_layer.fc2.weight
         bert_output.dense.bias = data2vec_layer.fc2.bias
         bert_output.LayerNorm.weight = data2vec_layer.final_layer_norm.weight

src/transformers/models/deprecated/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py

Lines changed: 3 additions & 3 deletions
@@ -180,9 +180,9 @@ def check_and_map_params(hf_param, gluon_param):
         gluon_param = to_torch(params[gluon_param])
         shape_gluon = gluon_param.shape
 
-        assert shape_hf == shape_gluon, (
-            f"The gluon parameter {gluon_param} has shape {shape_gluon}, but expects shape {shape_hf} for Transformers"
-        )
+        assert (
+            shape_hf == shape_gluon
+        ), f"The gluon parameter {gluon_param} has shape {shape_gluon}, but expects shape {shape_hf} for Transformers"
 
         return gluon_param
 

src/transformers/models/deprecated/realm/modeling_realm.py

Lines changed: 3 additions & 3 deletions
@@ -139,9 +139,9 @@ def load_tf_weights_in_realm(model, config, tf_checkpoint_path):
         elif m_name == "kernel":
             array = np.transpose(array)
         try:
-            assert pointer.shape == array.shape, (
-                f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
-            )
+            assert (
+                pointer.shape == array.shape
+            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
         except AssertionError as e:
             e.args += (pointer.shape, array.shape)
             raise

src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py

Lines changed: 3 additions & 3 deletions
@@ -1095,9 +1095,9 @@ def call(
             batch_size, sequence_length = shape_list(input_ids)[:2]
         else:
             batch_size, sequence_length = shape_list(inputs_embeds)[:2]
-        assert self.config.pad_token_id is not None or batch_size == 1, (
-            "Cannot handle batch sizes > 1 if no padding token is defined."
-        )
+        assert (
+            self.config.pad_token_id is not None or batch_size == 1
+        ), "Cannot handle batch sizes > 1 if no padding token is defined."
 
         if not tf.is_tensor(sequence_lengths):
             in_logits = logits[0:batch_size, sequence_lengths]

src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py

Lines changed: 6 additions & 6 deletions
@@ -155,9 +155,9 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
                 p_i.data = torch.from_numpy(arr_i)
         else:
             try:
-                assert pointer.shape == array.shape, (
-                    f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
-                )
+                assert (
+                    pointer.shape == array.shape
+                ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
             except AssertionError as e:
                 e.args += (pointer.shape, array.shape)
                 raise
@@ -1238,9 +1238,9 @@ def forward(
         else:
             batch_size, sequence_length = inputs_embeds.shape[:2]
 
-        assert self.config.pad_token_id is not None or batch_size == 1, (
-            "Cannot handle batch sizes > 1 if no padding token is defined."
-        )
+        assert (
+            self.config.pad_token_id is not None or batch_size == 1
+        ), "Cannot handle batch sizes > 1 if no padding token is defined."
         if self.config.pad_token_id is None:
             sequence_lengths = -1
         else:

0 commit comments
