Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 7ce3ada

Browse files
committed
fix test
1 parent 41a7d9b commit 7ce3ada

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

tests/unit/test_abstractive_summarization_seq2seq.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def s2s_test_data():
8484

8585
@pytest.mark.gpu
8686
def test_S2SAbstractiveSummarizer(s2s_test_data, tmp):
87-
processor = S2SAbsSumProcessor(cache_dir=tmp)
87+
cache_dir = tmp
88+
model_dir = tmp
89+
processor = S2SAbsSumProcessor(cache_dir=cache_dir)
8890
train_dataset = processor.s2s_dataset_from_json_or_file(
8991
s2s_test_data["train_ds"], train_mode=True
9092
)
@@ -95,14 +97,14 @@ def test_S2SAbstractiveSummarizer(s2s_test_data, tmp):
9597
max_seq_length=MAX_SEQ_LENGTH,
9698
max_source_seq_length=MAX_SOURCE_SEQ_LENGTH,
9799
max_target_seq_length=MAX_TARGET_SEQ_LENGTH,
98-
cache_dir=tmp,
100+
cache_dir=cache_dir,
99101
)
100102

101103
# test fit and predict
102-
abs_summarizer.fit(
104+
global_step = abs_summarizer.fit(
103105
train_dataset,
104106
per_gpu_batch_size=TRAIN_PER_GPU_BATCH_SIZE,
105-
save_model_to_dir=tmp,
107+
save_model_to_dir=model_dir,
106108
)
107109
abs_summarizer.predict(
108110
test_dataset,
@@ -112,12 +114,12 @@ def test_S2SAbstractiveSummarizer(s2s_test_data, tmp):
112114

113115
# test load model from local disk
114116
abs_summarizer_loaded = S2SAbstractiveSummarizer(
115-
load_model_from_dir=tmp,
116-
model_file_name="model.1.bin",
117+
load_model_from_dir=model_dir,
118+
model_file_name="model.{}.bin".format(global_step),
117119
max_seq_length=MAX_SEQ_LENGTH,
118120
max_source_seq_length=MAX_SOURCE_SEQ_LENGTH,
119121
max_target_seq_length=MAX_TARGET_SEQ_LENGTH,
120-
cache_dir=tmp,
122+
cache_dir=cache_dir,
121123
)
122124

123125
abs_summarizer_loaded.predict(
@@ -130,10 +132,10 @@ def test_S2SAbstractiveSummarizer(s2s_test_data, tmp):
130132
abs_summarizer.fit(
131133
train_dataset,
132134
per_gpu_batch_size=TRAIN_PER_GPU_BATCH_SIZE,
133-
save_model_to_dir=tmp,
134-
recover_step=1,
135-
recover_dir=tmp,
136-
max_steps=4,
135+
save_model_to_dir=model_dir,
136+
recover_step=global_step,
137+
recover_dir=model_dir,
138+
max_steps=global_step + 3,
137139
)
138140

139141
abs_summarizer.predict(

utils_nlp/models/transformers/abstractive_summarization_seq2seq.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def __init__(
524524
+ self.max_target_seq_length,
525525
)
526526
logger.info("Model config for seq2seq: %s", str(config))
527-
527+
528528
self.model = model_class.from_pretrained(
529529
model_to_load,
530530
config=config,
@@ -732,11 +732,12 @@ def fit(
732732
)
733733

734734
if save_model_to_dir is not None and local_rank in [-1, 0]:
735-
self.save_model(save_model_to_dir, global_step, fp16)
735+
self.save_model(save_model_to_dir, global_step - 1, fp16)
736736

737737
# release GPU memories
738738
self.model.cpu()
739739
torch.cuda.empty_cache()
740+
return global_step - 1
740741

741742
def predict(
742743
self,
@@ -896,8 +897,7 @@ def collate_fn(input_batch):
896897
is_roberta=is_roberta,
897898
no_segment_embedding=no_segment_embedding
898899
)
899-
# print(self._bert_model_name)
900-
# print(type(bert_config))
900+
901901
model = BertForSeq2SeqDecoder.from_pretrained(
902902
self._bert_model_name,
903903
bert_config,
@@ -955,7 +955,6 @@ def collate_fn(input_batch):
955955
batch_size=batch_size,
956956
collate_fn=collate_fn,
957957
)
958-
print(device)
959958
for batch, buf_id in tqdm(
960959
test_dataloader, desc="Evaluating", disable=not verbose
961960
):

0 commit comments

Comments (0)