Commit bd8151a

The initial code
1 parent 6d89263 commit bd8151a

21 files changed: +4,800 −2 lines

License_M3DM.txt

+1,165 (large diff not rendered)

README.md

+114 −2

@@ -1,2 +1,114 @@
-# AnomalyDetection-M3DM
-Code for CVPR 2023 paper "Multimodal Industrial Anomaly Detection via Hybrid Fusion"
# Multimodal Industrial Anomaly Detection via Hybrid Fusion

![pipeline](figures/pipeline.png)

- `The pipeline of Multi-3D-Memory (M3DM).` Our M3DM contains three important parts: (1) **Point Feature Alignment** (PFA) converts point-group features to plane features with interpolation and projection operations, where $\text{FPS}$ is farthest point sampling and $\mathcal F_{pt}$ is a pretrained Point Transformer; (2) **Unsupervised Feature Fusion** (UFF) fuses point features and image features with a patch-wise contrastive loss $\mathcal L_{con}$, where $\mathcal F_{rgb}$ is a Vision Transformer, $\chi_{rgb},\chi_{pt}$ are MLP layers, and $\sigma_r, \sigma_p$ are single fully connected layers; (3) **Decision Layer Fusion** (DLF) combines multimodal information with multiple memory banks and makes the final decision with two learnable modules $\mathcal D_a, \mathcal D_s$ for anomaly detection and segmentation, where $\mathcal{M}_{rgb}, \mathcal{M}_{fs}, \mathcal{M}_{pt}$ are memory banks, $\phi, \psi$ are score functions for single-memory-bank detection and segmentation, and $\mathcal{P}$ is the memory-bank building algorithm.
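For intuition, patch-wise contrastive losses of this kind are typically InfoNCE-style: the projected RGB and point features of the same patch form a positive pair, and every other patch in the batch serves as a negative. A plausible form (a sketch; the exact definition is given in the paper) is

$$\mathcal L_{con} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp\!\left(h_{rgb}^{i}\cdot h_{pt}^{i}/\tau\right)}{\sum_{j=1}^{N}\exp\!\left(h_{rgb}^{i}\cdot h_{pt}^{j}/\tau\right)}$$

where $h_{rgb}^{i}=\sigma_r(\chi_{rgb}(f_{rgb}^{i}))$ and $h_{pt}^{i}=\sigma_p(\chi_{pt}(f_{pt}^{i}))$ are the projected features of patch $i$, $N$ is the number of patches, and $\tau$ is a temperature.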
## Setup

We implement this repo with the following environment:

- Python 3.8
- PyTorch 1.9.0
- CUDA 11.3
Install the other packages via:

``` bash
pip install -r requirement.txt
# install knn_cuda
pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
# install pointnet2_ops_lib
pip install "git+git://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
```
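To verify that the two compiled dependencies installed correctly, a quick import check can help (a sketch; `KNN` and `pointnet2_utils` are the entry points these packages expose):

```python
# sanity-check the CUDA extensions
import torch
from knn_cuda import KNN                   # from the KNN_CUDA wheel
from pointnet2_ops import pointnet2_utils  # from pointnet2_ops_lib

print(torch.__version__, torch.version.cuda)  # expect 1.9.0 / 11.3 per Setup above
knn = KNN(k=8, transpose_mode=True)           # constructing the op loads the kernel
```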
## Data Download and Preprocess

### Dataset

- The `MVTec-3D AD` dataset can be downloaded from the [Official Website of MVTec-3D AD](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad).

- The `Eyecandies` dataset can be downloaded from the [Official Website of Eyecandies](https://eyecan-ai.github.io/eyecandies/).

After downloading, put the dataset in the `datasets` folder.
### Data Preprocessing

To run the preprocessing:

```bash
python utils/preprocessing.py datasets/mvtec3d/
```

The preprocessing may take a few hours to run.
### Checkpoints

The following table lists the pretrained models used in M3DM:

| Backbone | Pretrain Method |
| ----------------- | --------------- |
| Point Transformer | [Point-MAE](https://github.com/Pang-Yatian/Point-MAE/releases/download/main/pretrain.pth) |
| Point Transformer | [Point-Bert](https://cloud.tsinghua.edu.cn/f/202b29805eea45d7be92/?dl=1) |
| ViT-b/8 | [DINO](https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth) |
| ViT-b/8 | [Supervised ImageNet 1K](https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz) |
| ViT-b/8 | [Supervised ImageNet 21K](https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz) |
| ViT-s/8 | [DINO](https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth) |

Put the checkpoint files in the `checkpoints` folder.
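A convenience sketch (not part of the repo) for fetching two of these checkpoints; the local file names here are placeholders, so rename them to whatever `main.py` expects:

```python
# hypothetical helper: download two of the checkpoints listed above
import os
import urllib.request

URLS = {
    # local name (placeholder) -> source URL from the table
    'pointmae_pretrain.pth': 'https://github.com/Pang-Yatian/Point-MAE/releases/download/main/pretrain.pth',
    'dino_vitbase8_pretrain.pth': 'https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
}

os.makedirs('checkpoints', exist_ok=True)
for name, url in URLS.items():
    dest = os.path.join('checkpoints', name)
    if not os.path.exists(dest):
        print('downloading', url)
        urllib.request.urlretrieve(url, dest)
```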
## Train and Test

Train and test the double memory bank ("double lib") version and save the features for UFF training:

```bash
python3 main.py \
--method_name DINO+Point_MAE \
--memory_bank multiple \
--rgb_backbone_name vit_base_patch8_224_dino \
--xyz_backbone_name Point_MAE \
--save_feature True \
```

The saved patch features are what the UFF training below consumes (see `--data_path datasets/patch_lib`).
Train the UFF:

```bash
OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=1 fusion_pretrain.py \
--accum_iter 16 \
--lr 0.003 \
--batch_size 16 \
--data_path datasets/patch_lib \
--output_dir checkpoints \
```
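With `--batch_size 16` and `--accum_iter 16`, gradients are accumulated over 16 iterations before each optimizer update, for an effective batch size of 16 × 16 = 256 on a single GPU.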
Train and test the full setting with the following command:

```bash
python3 main.py \
--method_name DINO+Point_MAE+Fusion \
--use_uff \
--memory_bank multiple \
--rgb_backbone_name vit_base_patch8_224_dino \
--xyz_backbone_name Point_MAE \
--fusion_module_path checkpoints/{FUSION_CHECKPOINT}.pth \
```

Note: if you set `--method_name DINO` or `--method_name Point_MAE`, set `--memory_bank single` at the same time.
If you find this repository useful for your research, please use the following citation:

```bibtex
@misc{wang2023multimodal,
      title={Multimodal Industrial Anomaly Detection via Hybrid Fusion},
      author={Wang, Yue and Peng, Jinlong and Zhang, Jiangning and Yi, Ran and Wang, Yabiao and Wang, Chengjie},
      year={2023},
      eprint={2303.00601},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
## Thanks

Our repo is built on [3D-ADS](https://github.com/eliahuhorwitz/3D-ADS) and [MoCo-v3](https://github.com/facebookresearch/moco-v3); thanks for their extraordinary work!

dataset.py

+188
@@ -0,0 +1,188 @@
import os
import glob

import numpy as np
import torch  # needed for torch.load / torch.zeros / torch.where below
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from utils.mvtec3d_util import *  # read_tiff_organized_pc, organized_pc_to_depth_map, resize_organized_pc


def eyecandies_classes():
    return [
        'CandyCane',
        'ChocolateCookie',
        'ChocolatePraline',
        'Confetto',
        'GummyBear',
        'HazelnutTruffle',
        'LicoriceSandwich',
        'Lollipop',
        'Marshmallow',
        'PeppermintCandy',
    ]


def mvtec3d_classes():
    return [
        "bagel",
        "cable_gland",
        "carrot",
        "cookie",
        "dowel",
        "foam",
        "peach",
        "potato",
        "rope",
        "tire",
    ]


RGB_SIZE = 224


class BaseAnomalyDetectionDataset(Dataset):
    """Shared setup: data paths and the ImageNet-normalized RGB transform."""

    def __init__(self, split, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
        self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
        self.IMAGENET_STD = [0.229, 0.224, 0.225]
        self.cls = class_name
        self.size = img_size
        self.img_path = os.path.join(dataset_path, self.cls, split)
        self.rgb_transform = transforms.Compose(
            [transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)])


class PreTrainTensorDataset(Dataset):
    """Loads the patch-feature tensors saved by `main.py --save_feature True`."""

    def __init__(self, root_path):
        super().__init__()
        self.root_path = root_path
        self.tensor_paths = os.listdir(self.root_path)

    def __len__(self):
        return len(self.tensor_paths)

    def __getitem__(self, idx):
        tensor_path = self.tensor_paths[idx]
        tensor = torch.load(os.path.join(self.root_path, tensor_path))
        label = 0  # pretraining uses only nominal (good) samples
        return tensor, label

class TrainDataset(BaseAnomalyDetectionDataset):
    def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
        super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
        self.img_paths, self.labels = self.load_dataset()  # self.labels => good : 0, anomaly : 1

    def load_dataset(self):
        # Training uses only 'good' samples; pair each RGB image with its point cloud.
        img_tot_paths = []
        tot_labels = []
        rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png")
        tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff")
        rgb_paths.sort()
        tiff_paths.sort()
        sample_paths = list(zip(rgb_paths, tiff_paths))
        img_tot_paths.extend(sample_paths)
        tot_labels.extend([0] * len(sample_paths))
        return img_tot_paths, tot_labels

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path, label = self.img_paths[idx], self.labels[idx]
        rgb_path = img_path[0]
        tiff_path = img_path[1]
        img = Image.open(rgb_path).convert('RGB')
        img = self.rgb_transform(img)
        organized_pc = read_tiff_organized_pc(tiff_path)

        depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
        resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
        resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
        resized_organized_pc = resized_organized_pc.clone().detach().float()

        return (img, resized_organized_pc, resized_depth_map_3channel), label


class TestDataset(BaseAnomalyDetectionDataset):
    def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
        super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
        self.gt_transform = transforms.Compose([
            transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()])
        self.img_paths, self.gt_paths, self.labels = self.load_dataset()  # self.labels => good : 0, anomaly : 1

    def load_dataset(self):
        img_tot_paths = []
        gt_tot_paths = []
        tot_labels = []
        defect_types = os.listdir(self.img_path)

        for defect_type in defect_types:
            if defect_type == 'good':
                # 'good' samples have no ground-truth mask; use 0 as a placeholder.
                rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
                tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
                rgb_paths.sort()
                tiff_paths.sort()
                sample_paths = list(zip(rgb_paths, tiff_paths))
                img_tot_paths.extend(sample_paths)
                gt_tot_paths.extend([0] * len(sample_paths))
                tot_labels.extend([0] * len(sample_paths))
            else:
                rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
                tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
                gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png")
                rgb_paths.sort()
                tiff_paths.sort()
                gt_paths.sort()
                sample_paths = list(zip(rgb_paths, tiff_paths))

                img_tot_paths.extend(sample_paths)
                gt_tot_paths.extend(gt_paths)
                tot_labels.extend([1] * len(sample_paths))

        assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"

        return img_tot_paths, gt_tot_paths, tot_labels

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
        rgb_path = img_path[0]
        tiff_path = img_path[1]
        img_original = Image.open(rgb_path).convert('RGB')
        img = self.rgb_transform(img_original)

        organized_pc = read_tiff_organized_pc(tiff_path)
        depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
        resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
        resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
        resized_organized_pc = resized_organized_pc.clone().detach().float()

        if gt == 0:
            # all-zero mask for 'good' samples (height and width of the resized map)
            gt = torch.zeros(
                [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-1]])
        else:
            gt = Image.open(gt).convert('L')
            gt = self.gt_transform(gt)
            gt = torch.where(gt > 0.5, 1., .0)

        return (img, resized_organized_pc, resized_depth_map_3channel), gt[:1], label, rgb_path


def get_data_loader(split, class_name, img_size, args):
    if split in ['train']:
        dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path)
    elif split in ['test']:
        dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path)

    data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False,
                             pin_memory=True)
    return data_loader
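A usage sketch (not in the repo): build a test loader for one MVTec-3D class and inspect one sample. `args` only needs a `dataset_path` attribute, matching `get_data_loader` above.

```python
# hypothetical usage of get_data_loader; assumes preprocessed data in datasets/mvtec3d
import argparse
from dataset import get_data_loader, mvtec3d_classes

args = argparse.Namespace(dataset_path='datasets/mvtec3d')
loader = get_data_loader('test', class_name=mvtec3d_classes()[0], img_size=224, args=args)

(img, organized_pc, depth_3c), gt, label, rgb_path = next(iter(loader))
print(img.shape)  # torch.Size([1, 3, 224, 224]) -- batch_size is fixed to 1
print(gt.shape)   # [1, 1, 224, 224] ground-truth mask (all zeros for 'good' samples)
print(label, rgb_path)
```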

engine_fusion_pretrain.py

+75
@@ -0,0 +1,75 @@
import math
import sys
from typing import Iterable

import torch

import utils.misc as misc
import utils.lr_sched as lr_sched


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per-iteration (instead of per-epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        # each sample concatenates point (xyz) and RGB patch features along the
        # channel dimension; the first 1152 channels are the point features
        xyz_samples = samples[:, :, :1152].to(device, non_blocking=True)
        rgb_samples = samples[:, :, 1152:].to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss = model(xyz_samples, rgb_samples)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # scale the loss for gradient accumulation; step only every accum_iter iterations
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
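The loop above combines AMP autocast, a loss scaler, and gradient accumulation. A self-contained toy sketch of that pattern (using a plain `torch.cuda.amp.GradScaler` in place of the repo's `loss_scaler` wrapper, and a dummy model rather than the fusion module):

```python
# toy illustration of AMP + gradient accumulation; not the repo's fusion model
import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
scaler = torch.cuda.amp.GradScaler()
accum_iter = 4

optimizer.zero_grad()
for step in range(16):
    x = torch.randn(16, 8, device='cuda')
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).mean()
    # divide so the accumulated gradient averages over accum_iter micro-batches
    scaler.scale(loss / accum_iter).backward()
    if (step + 1) % accum_iter == 0:
        scaler.step(optimizer)  # unscales gradients, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()
```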
