Skip to content

Commit db0e13d

Browse files
committed
Finished test metrics
1 parent f01b3a4 commit db0e13d

File tree

5 files changed

+67
-10
lines changed

5 files changed

+67
-10
lines changed

requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,8 @@ urllib3=2.1.0=pypi_0
7878
wheel=0.41.3=pyhd8ed1ab_0
7979
xz=5.2.6=h166bdaf_0
8080
yarl=1.9.2=pypi_0
81+
lightning~=2.1.1
82+
torch~=2.1.0
83+
plotly~=5.18.0
84+
torchmetrics~=1.2.0
85+
pandas~=1.5.3

src/model.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ def __init__(self, input_size, hidden_size, num_heads, num_layers, num_classes,
99
super().__init__()
1010
self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
1111

12+
self.num_classes = num_classes
13+
1214
self.cnn = nn.Sequential(
1315
nn.Sequential(
1416
nn.Conv1d(input_size, hidden_size, kernel_size=3, padding=1),
@@ -75,8 +77,26 @@ def test_step(self, batch, batch_idx):
7577
pred = torch.argmax(logits, dim=1)
7678
y = torch.argmax(y, dim=1)
7779
acc = self.accuracy(pred, y)
78-
self.log('test_loss', loss, on_epoch=True, on_step=False)
79-
self.log('test_accuracy', acc, on_epoch=True, on_step=False)
80+
81+
cm = torchmetrics.functional.confusion_matrix(pred, y, task='multiclass', num_classes=self.num_classes)
82+
# Can't log tensors, and cm is multiclass, so have to log each class separately
83+
self.log('loss', loss, on_epoch=True, on_step=False)
84+
self.log('accuracy', acc, on_epoch=True, on_step=False)
85+
86+
for i in range(self.num_classes):
87+
false_positives = torch.sum(cm[:, i]) - cm[i, i]
88+
false_negatives = torch.sum(cm[i, :]) - cm[i, i]
89+
true_positives = cm[i, i]
90+
true_negatives = torch.sum(cm) - (false_positives + false_negatives + true_positives)
91+
92+
precision = true_positives / (true_positives + false_positives + 1e-8)
93+
recall = true_positives / (true_positives + false_negatives + 1e-8)
94+
f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
95+
96+
self.log(f'precision_{i}', precision, on_epoch=True, on_step=False)
97+
self.log(f'recall_{i}', recall, on_epoch=True, on_step=False)
98+
self.log(f'f1_{i}', f1, on_epoch=True, on_step=False)
99+
80100
return loss
81101

82102
def configure_optimizers(self):

src/split.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
import os
import shutil
import random

# Fraction of the dataset files held out for the test split.
split_ratio = 0.2
test_dir = "data/test"
train_dir = "data/train"
data_dir = "data/Sara_dataset"

# Make sure both destination directories exist before any move.
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

file_list = os.listdir(data_dir)

test_size = int(split_ratio * len(file_list))

# Randomise ordering so the split is not biased by directory listing order.
random.shuffle(file_list)

# First `test_size` shuffled files go to the test set, the remainder to train.
for destination, names in ((test_dir, file_list[:test_size]),
                           (train_dir, file_list[test_size:])):
    for name in names:
        shutil.move(os.path.join(data_dir, name),
                    os.path.join(destination, name))

# The source directory is now empty; delete it.
shutil.rmtree(data_dir)

src/test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from src.dataset import SeriesDataset
66

77

8+
from plotly import express as px
9+
10+
811
def test(args):
912
checkpoint = torch.load(args.checkpoint_path)
1013
hyperparams = checkpoint['hyperparameters']
@@ -24,8 +27,9 @@ def test(args):
2427
dataset = SeriesDataset(args.data_dir)
2528
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
2629

27-
trainer = lightning.Trainer(default_root_dir=args['checkpoint_dir'], max_epochs=args['epochs'])
28-
trainer.test(model=model, dataloaders=dataloader)
30+
trainer = lightning.Trainer(default_root_dir=hyperparams['checkpoint_dir'], max_epochs=hyperparams['max_epochs'])
31+
32+
metrics = trainer.test(model=model, dataloaders=dataloader)
2933

3034

3135
if __name__ == '__main__':
@@ -35,7 +39,7 @@ def test(args):
3539
parser.add_argument('--checkpoint_path', type=str, default="checkpoints/model_checkpoint.pth",
3640
help='Path to the model checkpoint file')
3741
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for testing')
38-
parser.add_argument('--data_dir', type=str, default="Sara_dataset/test", help='Dataset directory for loading series')
42+
parser.add_argument('--data_dir', type=str, default="data/test", help='Dataset directory for loading series')
3943

4044
args = parser.parse_args()
4145
test(args)

src/train.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def train(args):
1919
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
2020
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
2121

22-
trainer = lightning.Trainer(default_root_dir=args.checkpoint_dir, max_epochs=args.epochs)
22+
trainer = lightning.Trainer(default_root_dir=args.checkpoint_dir, max_epochs=args.max_epochs)
2323
trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
2424

2525
checkpoint = {
@@ -28,8 +28,6 @@ def train(args):
2828
}
2929
torch.save(checkpoint, f'{args.checkpoint_dir}/model_checkpoint.pth')
3030

31-
trainer.test(model=model, dataloaders=val_dataloader)
32-
3331

3432
def main():
3533
parser = argparse.ArgumentParser(description='CNN-Transformer Time Series Classification')
@@ -39,10 +37,10 @@ def main():
3937
parser.add_argument('--num_layers', type=int, default=2, help='Number of transformer layers')
4038
parser.add_argument('--num_classes', type=int, default=5, help='Number of classes for classification')
4139
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training')
42-
parser.add_argument('--data_dir', type=str, default="Sara_dataset/", help='Dataset directory for loading series')
40+
parser.add_argument('--data_dir', type=str, default="data/train", help='Dataset directory for loading series')
4341
parser.add_argument('--checkpoint_dir', type=str, default="checkpoints/", help='Directory to save checkpoints')
4442
parser.add_argument('--resume_training', action='store_true', help='Resume training from checkpoint')
45-
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train for')
43+
parser.add_argument('--max_epochs', type=int, default=100, help='Number of epochs to train for')
4644

4745
args = parser.parse_args()
4846
train(args)

0 commit comments

Comments
 (0)