Commit 07d004a

Working on test loop
Finished model training and accuracy
1 parent 6edbb72 commit 07d004a

File tree

  model.py
  test.py
  train.py

3 files changed: +56 -6 lines

model.py

Lines changed: 4 additions & 2 deletions

@@ -50,7 +50,7 @@ def forward(self, x):
         x = x.permute(1, 0, 2)
         x = x[:, -1, :]
         x = self.classifier(x)
-        return x
+        return nn.functional.softmax(x, dim=1)
 
     def training_step(self, batch, batch_idx):
         x, y = batch
@@ -72,7 +72,9 @@ def test_step(self, batch, batch_idx):
         x, y = batch
         logits = self.forward(x)
         loss = nn.functional.cross_entropy(logits, y)
-        acc = self.accuracy(logits, y)
+        pred = torch.argmax(logits, dim=1)
+        y = torch.argmax(y, dim=1)
+        acc = self.accuracy(pred, y)
        self.log('test_loss', loss, on_epoch=True, on_step=False)
         self.log('test_accuracy', acc, on_epoch=True, on_step=False)
         return loss
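
Two details of this hunk are worth spelling out: torch.nn.functional.cross_entropy applies log_softmax to its input internally, and torch.argmax is invariant under a row-wise softmax, so the predictions derived from the softmax-ed outputs match those from the raw logits. Below is a minimal, self-contained sketch of the argmax-based accuracy from test_step, assuming one-hot targets as in the diff and using a plain comparison in place of the model's self.accuracy metric:

import torch

# Toy batch: 2 samples, 3 classes; probs plays the role of forward(x).
probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.8, 0.1, 0.1]])
# One-hot targets, matching the y = torch.argmax(y, dim=1) line above.
y = torch.tensor([[0., 1., 0.],
                  [0., 0., 1.]])

pred = torch.argmax(probs, dim=1)       # predicted classes -> tensor([1, 0])
target = torch.argmax(y, dim=1)         # one-hot -> class indices -> tensor([1, 2])
acc = (pred == target).float().mean()   # 1 of 2 correct -> 0.5
print(acc.item())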

test.py

Lines changed: 40 additions & 0 deletions

@@ -1 +1,41 @@
+import lightning
+import torch
+from torch.utils.data import DataLoader
+from model import MultiClassAnomaly
+from dataset import SeriesDataset
 
+
+def test(args):
+    checkpoint = torch.load(args.checkpoint_path)
+    hyperparams = checkpoint['hyperparameters']
+
+    model_args = {
+        'input_size': hyperparams['input_size'],
+        'hidden_size': hyperparams['hidden_size'],
+        'num_heads': hyperparams['num_heads'],
+        'num_layers': hyperparams['num_layers'],
+        'num_classes': hyperparams['num_classes'],
+    }
+
+    model = MultiClassAnomaly(**model_args)
+    model.load_state_dict(checkpoint['state_dict'])
+    model.eval()
+
+    dataset = SeriesDataset(args.data_dir)
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
+
+    trainer = lightning.Trainer()  # test.py's args is a Namespace and defines no checkpoint_dir or epochs
+    trainer.test(model=model, dataloaders=dataloader)
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Test trained model on the entire dataset')
+    parser.add_argument('--checkpoint_path', type=str, default="checkpoints/model_checkpoint.pth",
+                        help='Path to the model checkpoint file')
+    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for testing')
+    parser.add_argument('--data_dir', type=str, default="Sara_dataset/test", help='Dataset directory for loading series')
+
+    args = parser.parse_args()
+    test(args)
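
test() relies on the checkpoint layout written by train.py: a plain dict with the parsed CLI arguments under 'hyperparameters' and the weights under 'state_dict'. A minimal sketch of that round trip, assuming model.py is importable and using hypothetical hyperparameter values:

import torch
from model import MultiClassAnomaly  # assumes model.py is on the path

# Hypothetical values; only the five keys test() reads are shown.
# train.py actually stores vars(args), which carries extra CLI keys too.
hyperparams = {'input_size': 1, 'hidden_size': 64, 'num_heads': 4,
               'num_layers': 2, 'num_classes': 5}

model = MultiClassAnomaly(**hyperparams)
torch.save({'hyperparameters': hyperparams,
            'state_dict': model.state_dict()},
           'checkpoints/model_checkpoint.pth')

# Mirror of what test() does on the loading side.
checkpoint = torch.load('checkpoints/model_checkpoint.pth')
restored = MultiClassAnomaly(**checkpoint['hyperparameters'])
restored.load_state_dict(checkpoint['state_dict'])
restored.eval()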

train.py

Lines changed: 12 additions & 4 deletions

@@ -14,15 +14,22 @@ def train(args):
     dataset = SeriesDataset(args.data_dir)
 
     train_size = int(0.8 * len(dataset))
-
     train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
 
     train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
     val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
 
-    trainer = lightning.Trainer(default_root_dir=args.checkpoint_dir, max_epochs=100)
+    trainer = lightning.Trainer(default_root_dir=args.checkpoint_dir, max_epochs=args.epochs)
     trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
 
+    checkpoint = {
+        'hyperparameters': vars(args),
+        'state_dict': model.state_dict(),
+    }
+    torch.save(checkpoint, f'{args.checkpoint_dir}/model_checkpoint.pth')
+
+    trainer.test(model=model, dataloaders=val_dataloader)
+
 
 def main():
     parser = argparse.ArgumentParser(description='CNN-Transformer Time Series Classification')
@@ -32,12 +39,13 @@ def main():
     parser.add_argument('--num_layers', type=int, default=2, help='Number of transformer layers')
     parser.add_argument('--num_classes', type=int, default=5, help='Number of classes for classification')
     parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training')
-
     parser.add_argument('--data_dir', type=str, default="Sara_dataset/", help='Dataset directory for loading series')
     parser.add_argument('--checkpoint_dir', type=str, default="checkpoints/", help='Directory to save checkpoints')
     parser.add_argument('--resume_training', action='store_true', help='Resume training from checkpoint')
+    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train for')
 
-    train(parser.parse_args())
+    args = parser.parse_args()
+    train(args)
 
 
 if __name__ == '__main__':
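
One behavior of the split in train() worth noting: torch.utils.data.random_split draws a fresh random permutation on every run, so the train/validation partition changes between invocations; it accepts an optional generator argument when a reproducible split is wanted. A toy sketch of the same 80/20 logic, where the seeded generator is an illustrative addition rather than part of the commit:

import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for SeriesDataset: 10 dummy items.
dataset = TensorDataset(torch.arange(10).float())

train_size = int(0.8 * len(dataset))              # 8
train_ds, val_ds = random_split(
    dataset, [train_size, len(dataset) - train_size],
    generator=torch.Generator().manual_seed(42),  # optional: reproducible split
)
print(len(train_ds), len(val_ds))                 # 8 2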
