Skip to content

Commit c777959

Browse files
committed
Updated project structure, added Dockerfile
1 parent db0e13d commit c777959

File tree

6 files changed

+29
-100
lines changed

6 files changed

+29
-100
lines changed

Dockerfile

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Use the official Python image as a base image
2+
FROM python:3.10-slim
3+
4+
WORKDIR /app
5+
6+
# Install any needed packages specified in requirements.txt
7+
COPY requirements.txt .
8+
RUN pip install -r requirements.txt
9+
10+
COPY checkpoints/model_checkpoint.pth checkpoints/
11+
COPY data/test data/test
12+
13+
COPY src/ src/
14+
15+
RUN pwd && ls -la
16+
CMD ["python", "src/test.py"]

requirements.txt

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,3 @@
1-
# This file may be used to create an environment using:
2-
# $ conda create --name <env> --file <this file>
3-
# platform: linux-64
4-
_libgcc_mutex=0.1=conda_forge
5-
_openmp_mutex=4.5=2_gnu
6-
aiohttp=3.8.6=pypi_0
7-
aiosignal=1.3.1=pypi_0
8-
async-timeout=4.0.3=pypi_0
9-
attrs=23.1.0=pypi_0
10-
bzip2=1.0.8=hd590300_5
11-
ca-certificates=2023.7.22=hbcca054_0
12-
certifi=2023.7.22=pypi_0
13-
charset-normalizer=3.3.2=pypi_0
14-
filelock=3.13.1=pypi_0
15-
frozenlist=1.4.0=pypi_0
16-
fsspec=2023.10.0=pypi_0
17-
idna=3.4=pypi_0
18-
jinja2=3.1.2=pypi_0
19-
ld_impl_linux-64=2.40=h41732ed_0
20-
libblas=3.9.0=19_linux64_openblas
21-
libcblas=3.9.0=19_linux64_openblas
22-
libffi=3.4.2=h7f98852_5
23-
libgcc-ng=13.2.0=h807b86a_3
24-
libgfortran-ng=13.2.0=h69a702a_3
25-
libgfortran5=13.2.0=ha4646dd_3
26-
libgomp=13.2.0=h807b86a_3
27-
liblapack=3.9.0=19_linux64_openblas
28-
libnsl=2.0.1=hd590300_0
29-
libopenblas=0.3.24=pthreads_h413a1c8_0
30-
libsqlite=3.44.0=h2797004_0
31-
libstdcxx-ng=13.2.0=h7e041cc_3
32-
libuuid=2.38.1=h0b41bf4_0
33-
libzlib=1.2.13=hd590300_5
34-
lightning=2.1.1=pypi_0
35-
lightning-utilities=0.9.0=pypi_0
36-
markupsafe=2.1.3=pypi_0
37-
mpmath=1.3.0=pypi_0
38-
multidict=6.0.4=pypi_0
39-
ncurses=6.4=h59595ed_2
40-
networkx=3.2.1=pypi_0
41-
numpy=1.26.2=pypi_0
42-
nvidia-cublas-cu12=12.1.3.1=pypi_0
43-
nvidia-cuda-cupti-cu12=12.1.105=pypi_0
44-
nvidia-cuda-nvrtc-cu12=12.1.105=pypi_0
45-
nvidia-cuda-runtime-cu12=12.1.105=pypi_0
46-
nvidia-cudnn-cu12=8.9.2.26=pypi_0
47-
nvidia-cufft-cu12=11.0.2.54=pypi_0
48-
nvidia-curand-cu12=10.3.2.106=pypi_0
49-
nvidia-cusolver-cu12=11.4.5.107=pypi_0
50-
nvidia-cusparse-cu12=12.1.0.106=pypi_0
51-
nvidia-nccl-cu12=2.18.1=pypi_0
52-
nvidia-nvjitlink-cu12=12.3.52=pypi_0
53-
nvidia-nvtx-cu12=12.1.105=pypi_0
54-
openssl=3.1.4=hd590300_0
55-
packaging=23.2=pypi_0
56-
pandas=2.1.3=py310hcc13569_0
57-
pip=23.3.1=pyhd8ed1ab_0
58-
python=3.10.13=hd12c33a_0_cpython
59-
python-dateutil=2.8.2=pyhd8ed1ab_0
60-
python-tzdata=2023.3=pyhd8ed1ab_0
61-
python_abi=3.10=4_cp310
62-
pytorch-lightning=2.1.1=pypi_0
63-
pytz=2023.3.post1=pyhd8ed1ab_0
64-
pyyaml=6.0.1=pypi_0
65-
readline=8.2=h8228510_1
66-
requests=2.31.0=pypi_0
67-
setuptools=68.2.2=pyhd8ed1ab_0
68-
six=1.16.0=pyh6c4a22f_0
69-
sympy=1.12=pypi_0
70-
tk=8.6.13=noxft_h4845f30_101
71-
torch=2.1.0=pypi_0
72-
torchmetrics=1.2.0=pypi_0
73-
tqdm=4.66.1=pypi_0
74-
triton=2.1.0=pypi_0
75-
typing-extensions=4.8.0=pypi_0
76-
tzdata=2023c=h71feb2d_0
77-
urllib3=2.1.0=pypi_0
78-
wheel=0.41.3=pyhd8ed1ab_0
79-
xz=5.2.6=h166bdaf_0
80-
yarl=1.9.2=pypi_0
811
lightning~=2.1.1
822
torch~=2.1.0
833
plotly~=5.18.0

src/__init__.py

Whitespace-only changes.

src/model.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def __init__(self, input_size, hidden_size, num_heads, num_layers, num_classes,
1010
self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
1111

1212
self.num_classes = num_classes
13+
self.confusion_matrix = torch.zeros(num_classes, num_classes)
1314

1415
self.cnn = nn.Sequential(
1516
nn.Sequential(
@@ -78,24 +79,12 @@ def test_step(self, batch, batch_idx):
7879
y = torch.argmax(y, dim=1)
7980
acc = self.accuracy(pred, y)
8081

81-
cm = torchmetrics.functional.confusion_matrix(pred, y, task='multiclass', num_classes=self.num_classes)
82-
# Can't log tensors, and cm is multiclass, so have to log each class separately
8382
self.log('loss', loss, on_epoch=True, on_step=False)
8483
self.log('accuracy', acc, on_epoch=True, on_step=False)
8584

86-
for i in range(self.num_classes):
87-
false_positives = torch.sum(cm[:, i]) - cm[i, i]
88-
false_negatives = torch.sum(cm[i, :]) - cm[i, i]
89-
true_positives = cm[i, i]
90-
true_negatives = torch.sum(cm) - (false_positives + false_negatives + true_positives)
91-
92-
precision = true_positives / (true_positives + false_positives + 1e-8)
93-
recall = true_positives / (true_positives + false_negatives + 1e-8)
94-
f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
95-
96-
self.log(f'precision_{i}', precision, on_epoch=True, on_step=False)
97-
self.log(f'recall_{i}', recall, on_epoch=True, on_step=False)
98-
self.log(f'f1_{i}', f1, on_epoch=True, on_step=False)
85+
self.confusion_matrix = (self.confusion_matrix +
86+
torchmetrics.functional.confusion_matrix(pred.cpu(), y.cpu(), 'multiclass',
87+
num_classes=self.num_classes))
9988

10089
return loss
10190

src/test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import lightning
22
import torch
33
from torch.utils.data import DataLoader
4-
from src.model import MultiClassAnomaly
5-
from src.dataset import SeriesDataset
4+
from model import MultiClassAnomaly
5+
from dataset import SeriesDataset
66

77

88
from plotly import express as px
@@ -29,7 +29,11 @@ def test(args):
2929

3030
trainer = lightning.Trainer(default_root_dir=hyperparams['checkpoint_dir'], max_epochs=hyperparams['max_epochs'])
3131

32-
metrics = trainer.test(model=model, dataloaders=dataloader)
32+
trainer.test(model=model, dataloaders=dataloader)
33+
34+
confusion_matrix = model.confusion_matrix
35+
heatmap = px.imshow(confusion_matrix, labels=dict(x="Predicted", y="Actual", color="Count"), text_auto=True)
36+
heatmap.show()
3337

3438

3539
if __name__ == '__main__':

src/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from torch.utils.data import DataLoader
55
import lightning
66

7-
from src.model import MultiClassAnomaly
8-
from src.dataset import SeriesDataset
7+
from model import MultiClassAnomaly
8+
from dataset import SeriesDataset
99

1010

1111
def train(args):

0 commit comments

Comments
 (0)