
Commit 01d5c1c

Initial commit
1 parent 8867c24 commit 01d5c1c

File tree

7 files changed: 205 additions & 95 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+checkpoints/
+lightning_logs/
+Sara_dataset/

README.md

Lines changed: 6 additions & 0 deletions
@@ -1,2 +1,8 @@
 # ml-anomaly
 Multi-class anomaly detection on time series using a hybrid Transformer/CNN for physical sensor readings. Improves on the analysis in the dataset's accompanying paper, Evaluating Conveyor Belt Health With Signal Processing Applied to Inertial Sensing.
+
+## Installation
+```bash
+conda create --name <env> --file requirements.txt
+```
+

dataset.py

Lines changed: 27 additions & 4 deletions
@@ -1,9 +1,28 @@
 import os
+from enum import Enum
+
 import pandas as pd
+import torch
 from torch.utils.data import Dataset
 
 
+class Labels(Enum):
+    NORMAL = 0
+    ANOMALY_1 = 1
+    ANOMALY_2 = 2
+    ANOMALY_3 = 3
+    ANOMALY_4 = 4
+
+
 class SeriesDataset(Dataset):
+    label_map = {
+        "Situação referencia index": 0,
+        "Situação rolo 3 da esquerda levantado index": 1,
+        "Situação rolo 3 da esquerda removido index": 2,
+        "Situação rolo 3 da direita levantado index": 3,
+        "Situação rolo 3 da direita removido index": 4,
+    }
+
     def __init__(self, data_dir):
         self.data_dir = data_dir
         self.file_list = [f for f in os.listdir(data_dir) if f.endswith('.csv')]
@@ -12,9 +31,13 @@ def __len__(self):
         return len(self.file_list)
 
     def __getitem__(self, idx):
+        # filename = <LABEL> <ID>.csv
         file_path = os.path.join(self.data_dir, self.file_list[idx])
-        data = pd.read_csv(file_path).values  # Assuming CSV files have numeric values
-        x = torch.FloatTensor(data[:, :-1])  # Input features
-        y = torch.LongTensor(data[:, -1])  # Labels
-        return x, y
 
+        label = " ".join(self.file_list[idx].split('.')[0].split(' ')[:-1])
+        y = torch.zeros(5).scatter_(0, torch.LongTensor([self.label_map[label]]), 1)
+
+        data = pd.read_csv(file_path).values
+        x = torch.FloatTensor(data)
+
+        return x, y
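
Since the labels live in the CSV filenames rather than in a data column, a minimal sketch of the parsing in __getitem__ may help; the filename below is hypothetical but follows the "<LABEL> <ID>.csv" convention noted in the code:

```python
import torch

# Hypothetical filename following the "<LABEL> <ID>.csv" convention.
filename = "Situação referencia index 3.csv"

# Strip the extension, then drop the trailing ID token to recover the label text.
label = " ".join(filename.split('.')[0].split(' ')[:-1])
assert label == "Situação referencia index"

# One-hot encode with the same scatter_ call used in SeriesDataset.__getitem__;
# the map here is an abridged copy of SeriesDataset.label_map.
label_map = {"Situação referencia index": 0}
y = torch.zeros(5).scatter_(0, torch.LongTensor([label_map[label]]), 1)
print(y)  # tensor([1., 0., 0., 0., 0.])
```

Note that the target is a float one-hot vector rather than an integer class index; nn.functional.cross_entropy in model.py accepts probability targets like this on PyTorch >= 1.10, which the pinned torch 2.1.0 satisfies.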

model.py

Lines changed: 65 additions & 20 deletions
@@ -1,36 +1,81 @@
 import torch
-import torch.nn as nn
+import torchmetrics
+from torch import nn
+import lightning
 
-class MultiClassAnomaly(nn.Module):
-    def __init__(self, input_size, hidden_size, num_heads, num_layers, num_classes, dropout_rate=0.1):
-        super(MultiClassAnomaly, self).__init__()
+
+class MultiClassAnomaly(lightning.LightningModule):
+    def __init__(self, input_size, hidden_size, num_heads, num_layers, num_classes, dropout_rate=0.5):
+        super().__init__()
+        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
 
         self.cnn = nn.Sequential(
-            nn.Conv1d(in_channels=input_size, out_channels=hidden_size, kernel_size=3, padding=1),
-            nn.ReLU(),
-            nn.MaxPool1d(kernel_size=2)
+            nn.Sequential(
+                nn.Conv1d(input_size, hidden_size, kernel_size=3, padding=1),
+                nn.BatchNorm1d(hidden_size),
+                nn.ReLU(),
+                nn.MaxPool1d(2),
+            ),
+            nn.Sequential(
+                nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1),
+                nn.BatchNorm1d(hidden_size),
+                nn.ReLU(),
+                nn.MaxPool1d(2),
+            )
         )
 
-        self.embedding = nn.Linear(hidden_size, hidden_size)
-        self.transformer_encoder = nn.TransformerEncoder(
-            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads),
+        self.transformer = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(
+                d_model=hidden_size,
+                nhead=num_heads,
+                dim_feedforward=hidden_size,
+                dropout=dropout_rate,
+                activation='relu'
+            ),
             num_layers=num_layers
         )
 
-        self.fc = nn.Linear(hidden_size, num_classes)
-
-        self.dropout = nn.Dropout(p=dropout_rate)
+        self.classifier = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size),
+            nn.ReLU(),
+            nn.Dropout(dropout_rate),
+            nn.Linear(hidden_size, num_classes)
+        )
 
     def forward(self, x):
-        x = self.cnn(x)
         x = x.permute(0, 2, 1)
-        x = self.embedding(x)
+        x = self.cnn(x)
+        x = x.permute(2, 0, 1)
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)
+        x = x[:, -1, :]
+        x = self.classifier(x)
+        return x
 
-        x = self.transformer_encoder(x)
-        x = x.mean(dim=1)
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        loss = nn.functional.cross_entropy(logits, y)
+        self.log('train_loss', loss, on_epoch=True, on_step=True)
+        return loss
 
-        x = self.dropout(x)
-        output = self.fc(x)
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        loss = nn.functional.cross_entropy(logits, y)
+        acc = self.accuracy(logits, y)
+        self.log('val_loss', loss, on_epoch=True, on_step=False)
+        self.log('val_accuracy', acc, on_epoch=True, on_step=False)
+        return loss
 
-        return output
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        loss = nn.functional.cross_entropy(logits, y)
+        acc = self.accuracy(logits, y)
+        self.log('test_loss', loss, on_epoch=True, on_step=False)
+        self.log('test_accuracy', acc, on_epoch=True, on_step=False)
+        return loss
 
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.001)
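
For orientation, here is a hedged trace of tensor shapes through forward, assuming the train.py defaults (input_size=7, hidden_size=64, batch_size=8) and an illustrative window length of 256 samples; the window length is an assumption, not something fixed by the commit:

```python
import torch

x = torch.randn(8, 256, 7)   # (batch, seq_len, features) as yielded by SeriesDataset

x = x.permute(0, 2, 1)       # (8, 7, 256): channels-first for Conv1d
# self.cnn: two conv blocks, each halving the length with MaxPool1d(2)
#   -> (8, 64, 64) = (batch, hidden_size, seq_len / 4)
x = x.permute(2, 0, 1)       # (64, 8, 64): (seq, batch, d_model), the default
                             # layout for nn.TransformerEncoder (batch_first=False)
# self.transformer preserves the shape -> (64, 8, 64)
x = x.permute(1, 0, 2)       # (8, 64, 64): back to batch-first
x = x[:, -1, :]              # (8, 64): keep only the final time step
# self.classifier then maps (8, 64) -> (8, num_classes) logits
```

Pooling by the last time step replaces the x.mean(dim=1) of the deleted nn.Module version. One caveat: torchmetrics' multiclass Accuracy expects integer class targets, so the validation and test steps would likely need y.argmax(dim=1) to accept the one-hot targets produced by SeriesDataset.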

requirements.txt

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# This file may be used to create an environment using:
+# $ conda create --name <env> --file <this file>
+# platform: linux-64
+_libgcc_mutex=0.1=conda_forge
+_openmp_mutex=4.5=2_gnu
+aiohttp=3.8.6=pypi_0
+aiosignal=1.3.1=pypi_0
+async-timeout=4.0.3=pypi_0
+attrs=23.1.0=pypi_0
+bzip2=1.0.8=hd590300_5
+ca-certificates=2023.7.22=hbcca054_0
+certifi=2023.7.22=pypi_0
+charset-normalizer=3.3.2=pypi_0
+filelock=3.13.1=pypi_0
+frozenlist=1.4.0=pypi_0
+fsspec=2023.10.0=pypi_0
+idna=3.4=pypi_0
+jinja2=3.1.2=pypi_0
+ld_impl_linux-64=2.40=h41732ed_0
+libblas=3.9.0=19_linux64_openblas
+libcblas=3.9.0=19_linux64_openblas
+libffi=3.4.2=h7f98852_5
+libgcc-ng=13.2.0=h807b86a_3
+libgfortran-ng=13.2.0=h69a702a_3
+libgfortran5=13.2.0=ha4646dd_3
+libgomp=13.2.0=h807b86a_3
+liblapack=3.9.0=19_linux64_openblas
+libnsl=2.0.1=hd590300_0
+libopenblas=0.3.24=pthreads_h413a1c8_0
+libsqlite=3.44.0=h2797004_0
+libstdcxx-ng=13.2.0=h7e041cc_3
+libuuid=2.38.1=h0b41bf4_0
+libzlib=1.2.13=hd590300_5
+lightning=2.1.1=pypi_0
+lightning-utilities=0.9.0=pypi_0
+markupsafe=2.1.3=pypi_0
+mpmath=1.3.0=pypi_0
+multidict=6.0.4=pypi_0
+ncurses=6.4=h59595ed_2
+networkx=3.2.1=pypi_0
+numpy=1.26.2=pypi_0
+nvidia-cublas-cu12=12.1.3.1=pypi_0
+nvidia-cuda-cupti-cu12=12.1.105=pypi_0
+nvidia-cuda-nvrtc-cu12=12.1.105=pypi_0
+nvidia-cuda-runtime-cu12=12.1.105=pypi_0
+nvidia-cudnn-cu12=8.9.2.26=pypi_0
+nvidia-cufft-cu12=11.0.2.54=pypi_0
+nvidia-curand-cu12=10.3.2.106=pypi_0
+nvidia-cusolver-cu12=11.4.5.107=pypi_0
+nvidia-cusparse-cu12=12.1.0.106=pypi_0
+nvidia-nccl-cu12=2.18.1=pypi_0
+nvidia-nvjitlink-cu12=12.3.52=pypi_0
+nvidia-nvtx-cu12=12.1.105=pypi_0
+openssl=3.1.4=hd590300_0
+packaging=23.2=pypi_0
+pandas=2.1.3=py310hcc13569_0
+pip=23.3.1=pyhd8ed1ab_0
+python=3.10.13=hd12c33a_0_cpython
+python-dateutil=2.8.2=pyhd8ed1ab_0
+python-tzdata=2023.3=pyhd8ed1ab_0
+python_abi=3.10=4_cp310
+pytorch-lightning=2.1.1=pypi_0
+pytz=2023.3.post1=pyhd8ed1ab_0
+pyyaml=6.0.1=pypi_0
+readline=8.2=h8228510_1
+requests=2.31.0=pypi_0
+setuptools=68.2.2=pyhd8ed1ab_0
+six=1.16.0=pyh6c4a22f_0
+sympy=1.12=pypi_0
+tk=8.6.13=noxft_h4845f30_101
+torch=2.1.0=pypi_0
+torchmetrics=1.2.0=pypi_0
+tqdm=4.66.1=pypi_0
+triton=2.1.0=pypi_0
+typing-extensions=4.8.0=pypi_0
+tzdata=2023c=h71feb2d_0
+urllib3=2.1.0=pypi_0
+wheel=0.41.3=pyhd8ed1ab_0
+xz=5.2.6=h166bdaf_0
+yarl=1.9.2=pypi_0

train.py

Lines changed: 24 additions & 28 deletions
@@ -1,48 +1,44 @@
+import argparse
 
+import torch
+from torch.utils.data import DataLoader
+import lightning
 
-def train_model(args):
-    model = CNNTransformerModel(args.input_size, args.hidden_size, args.num_heads, args.num_layers, args.num_classes)
+from model import MultiClassAnomaly
+from dataset import SeriesDataset
 
-    model_checkpoint = pl.callbacks.ModelCheckpoint(
-        dirpath=args.checkpoint_dir,
-        filename='best_model',
-        monitor='val_loss',
-        mode='min',
-        save_top_k=1
-    )
 
-    train_dataset = TimeSeriesDataset(args.data_dir)
-    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
+def train(args):
+    model = MultiClassAnomaly(args.input_size, args.hidden_size, args.num_heads, args.num_layers, args.num_classes)
 
-    val_dataset = TimeSeriesDataset(args.val_data_dir)
-    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
+    dataset = SeriesDataset(args.data_dir)
 
-    trainer = CNNTransformerTrainer(
-        model,
-        train_loader=train_loader,
-        val_loader=val_loader,
-        test_loader=test_loader,
-        model_checkpoint=model_checkpoint
-    )
+    train_size = int(0.8 * len(dataset))
 
-    if args.resume_training:
-        trainer.load_checkpoint(args.resume_checkpoint)
+    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
 
-    trainer.fit()
+    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
+    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
 
-    path = os.path.join(args.checkpoint_dir, 'final_model.pth')
-    trainer.save_checkpoint(path)
+    trainer = lightning.Trainer(default_root_dir=args.checkpoint_dir, max_epochs=100)
+    trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
 
 
 def main():
     parser = argparse.ArgumentParser(description='CNN-Transformer Time Series Classification')
-    parser.add_argument('--input_size', type=int, default=6, help='Number of input features')
+    parser.add_argument('--input_size', type=int, default=7, help='Number of input features')
     parser.add_argument('--hidden_size', type=int, default=64, help='Hidden size for the model')
     parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads')
     parser.add_argument('--num_layers', type=int, default=2, help='Number of transformer layers')
     parser.add_argument('--num_classes', type=int, default=5, help='Number of classes for classification')
     parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training')
-    parser.add_argument('--data_dir', type=str, default, help='Dataset directory from which to load series')
 
-    train(parser.args)
+    parser.add_argument('--data_dir', type=str, default="Sara_dataset/", help='Dataset directory for loading series')
+    parser.add_argument('--checkpoint_dir', type=str, default="checkpoints/", help='Directory to save checkpoints')
+    parser.add_argument('--resume_training', action='store_true', help='Resume training from checkpoint')
 
+    train(parser.parse_args())
+
+
+if __name__ == '__main__':
+    main()
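
As a usage sketch, training can be launched with python train.py; the programmatic equivalent below simply mirrors the argparse defaults above (nothing here is new configuration) and assumes the Sara_dataset/ CSVs are present:

```python
from argparse import Namespace

from train import train

args = Namespace(
    input_size=7,       # sensor channels per CSV row (train.py default)
    hidden_size=64,
    num_heads=4,
    num_layers=2,
    num_classes=5,      # five labelled belt conditions
    batch_size=8,
    data_dir="Sara_dataset/",
    checkpoint_dir="checkpoints/",
)
train(args)
```

Two caveats worth noting: the --resume_training flag is parsed but never read inside train(), and DataLoader's default collate requires every series in a batch to have the same length, so recordings of different lengths would need padding or batch_size=1.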

trainer.py

Lines changed: 0 additions & 43 deletions
This file was deleted.
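
Its responsibilities (checkpointing and the fit loop) are now covered by the stock lightning.Trainer constructed in train.py.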
