-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
79 lines (69 loc) · 2.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import os
from puad.dataset import build_dataset
from puad.efficientad.inference import load_efficient_ad
from puad.puad import PUAD
import torch
# Force full float32 precision on CUDA: on Ampere+ GPUs PyTorch may silently
# use TF32 for matmul/cuDNN kernels, which changes numerics slightly.
# NOTE(review): presumably disabled here for reproducible anomaly scores /
# AUROC values — confirm against the paper's evaluation setup.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def main() -> None:
    """Evaluate EfficientAD and PUAD on one category of an anomaly dataset.

    Expects ``dataset_path`` to point at ``<dataset_name>/<category>`` and to
    contain ``train`` and ``test`` subdirectories (plus ``validation`` for the
    MVTec LOCO AD dataset, per the argparse help text). Loads pretrained
    EfficientAD weights from ``model_dir_path``, prints the EfficientAD AUROC,
    then trains/validates PUAD on top of it and prints PUAD's overall AUROC
    and its per-anomaly-class breakdown.

    Raises:
        ValueError: if ``dataset_path`` lacks ``train``/``test`` directories.
    """
    parser = argparse.ArgumentParser(description="PUAD")
    parser.add_argument(
        "dataset_path",
        type=str,
        help="Path to dataset directory containing `train` and `test` (and `validation` in MVTec LOCO AD Dataset)",
    )
    parser.add_argument(
        "model_dir_path",
        type=str,
        help="Path to directory containing pretrained models",
    )
    parser.add_argument(
        "--size",
        choices=["s", "m"],
        type=str,
        default="s",
        help=(
            "Specify the size of EfficientAD used for Picturable anomaly detection "
            "and feature extraction for Unpicturable anomaly detection in either `s` or `m`"
        ),
    )
    parser.add_argument(
        "--feature_extractor",
        choices=["student", "teacher"],
        type=str,
        default="student",
        help=(
            "Specify the network in EfficientAD used for feature extraction for Unpicturable anomaly detection "
            "in either `teacher` or `student`"
        ),
    )
    args = parser.parse_args()

    # Derive <dataset_name>/<category> from the trailing path components.
    dataset_dir, category = os.path.split(os.path.abspath(args.dataset_path))
    dataset_name = os.path.split(dataset_dir)[1]

    # Fail fast, before loading any model weights. isdir (not exists) so a
    # stray *file* named `train` or `test` is rejected too.
    if not (
        os.path.isdir(os.path.join(args.dataset_path, "train"))
        and os.path.isdir(os.path.join(args.dataset_path, "test"))
    ):
        raise ValueError("The dataset specified in `dataset_path` must contain `train` and `test` directories.")

    print(f"dataset name : {dataset_name}")
    print(f"category : {category}")
    print(f"size : {args.size}")
    print(f"feature extractor : {args.feature_extractor}")

    # load EfficientAD
    efficient_ad_inference = load_efficient_ad(args.model_dir_path, args.size, dataset_name, category)
    # build dataset
    train_dataset, valid_dataset, test_dataset = build_dataset(args.dataset_path)

    # Baseline: plain EfficientAD AUROC on the test split.
    efficient_ad_auroc = efficient_ad_inference.auroc(test_dataset)
    print(f"efficient_ad auroc : {efficient_ad_auroc}")

    # PUAD reuses the loaded EfficientAD networks as its feature extractor,
    # then fits on train and calibrates on validation before evaluation.
    puad = PUAD(feature_extractor=args.feature_extractor)
    puad.load_efficient_ad(efficient_ad_inference)
    puad.train(train_dataset)
    puad.valid(valid_dataset)
    puad_auroc, puad_auroc_for_anomalies = puad.auroc_for_anomalies(test_dataset)
    print(f"puad auroc : {puad_auroc}")
    for anomaly_class, auroc_for_anomaly in puad_auroc_for_anomalies.items():
        print(f"puad auroc for {anomaly_class}: {auroc_for_anomaly}")


if __name__ == "__main__":
    main()