-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
106 lines (87 loc) · 2.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import pytorch_lightning as pl
from transformers import RobertaTokenizer
from factories.ModelFactory import ModelFactory
from transformers import (
RobertaTokenizer,
RobertaModel,
)
import torch
import torch.multiprocessing as mp
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import RobertaConfig, RobertaModel
from dataloader.MyDataloader import MyDataloader
from utils.utils import save_results
from pytorch_lightning.callbacks import EarlyStopping
# Trade a little float32 matmul precision for Tensor-Core speed.
torch.set_float32_matmul_precision("medium")

# --- Model hyperparameters -------------------------------------------------
input_dim = 100         # presumably consumed inside ModelFactory/model classes — TODO confirm
emb_dim = 768           # RoBERTa-base hidden size
mlp_dims = [512, 256]   # classifier MLP hidden-layer widths
lr = 0.0001             # Adam-style learning rate passed to ModelFactory
dropout = 0.2
weight_decay = 0.001
save_param_dir = "./params"  # where trained parameters/checkpoints are saved

# --- Data / training settings ----------------------------------------------
max_len = 170           # tokenizer truncation length
epochs = 50             # max training epochs (early stopping may end sooner)
batch_size = 64
subset_size = 128       # cap on samples per split for quick runs; None uses the full dataset
category_dict = {
    # news-domain label mapping
    "gossipcop": 0,
    "politifact": 1,
    "COVID": 2,
}
num_workers = 3         # DataLoader worker processes
train_path = "./data/en/train.pkl"
val_path = "./data/en/val.pkl"
test_path = "./data/en/test.pkl"
if __name__ == "__main__":
    # Make sure the parameter/checkpoint directory exists before training starts.
    os.makedirs(save_param_dir, exist_ok=True)

    # Frozen RoBERTa backbone — gradients disabled, used as a fixed feature extractor.
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    bert = RobertaModel.from_pretrained("roberta-base").requires_grad_(False)

    # Build the three dataloaders (train shuffled, val/test not).
    loader = MyDataloader(
        max_len=max_len,
        batch_size=batch_size,
        subset_size=subset_size,
        category_dict=category_dict,
        num_workers=num_workers,
        tokenizer=tokenizer,
    )
    train_loader = loader.load_data(train_path, True)
    val_loader = loader.load_data(val_path, False)
    test_loader = loader.load_data(test_path, False)

    # Instantiate the requested model; the factory may also supply a model-specific callback.
    model_name = "M3FEND"
    factory = ModelFactory(
        emb_dim=emb_dim,
        mlp_dims=mlp_dims,
        lr=lr,
        dropout=dropout,
        category_dict=category_dict,
        weight_decay=weight_decay,
        save_param_dir=save_param_dir,
        bert=bert,
        train_loader=train_loader,
    )
    model, callback = factory.create_model(model_name)

    # Optional factory callback plus early stopping on validation loss.
    callbacks = [] if callback is None else [callback]
    callbacks.append(
        EarlyStopping(
            monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min"
        )
    )

    logger = TensorBoardLogger(save_dir="logs", name="single_runs", version=model_name)
    trainer = pl.Trainer(
        max_epochs=epochs,
        accelerator="gpu",
        logger=logger,
        callbacks=callbacks,
    )

    # Train, evaluate on the held-out test split, and persist the metrics.
    trainer.fit(model, train_loader, val_loader)
    result = trainer.test(model, dataloaders=test_loader)
    print("Results:", result[0])
    save_results("single_research_results", model_name=model_name, results=result[0])