https://github.com/MediaBrain-SJTU/RegAD

# Registration based Few-Shot Anomaly Detection

This is an official implementation of “Registration based Few-Shot Anomaly Detection” (RegAD) with PyTorch, accepted by ECCV 2022 (Oral).

[Paper Link](https://arxiv.org/abs/2207.07361)

```
@inproceedings{huang2022regad,
title={Registration based Few-Shot Anomaly Detection},
author={Huang, Chaoqin and Guan, Haoyan and Jiang, Aofan and Zhang, Ya and Spratling, Michael and Wang, Yanfeng},
booktitle={European Conference on Computer Vision (ECCV)},
year={2022}
}
```

**Abstract**: This paper considers few-shot anomaly detection (FSAD), a practical yet under-studied setting for anomaly detection (AD), where only a limited number of normal images are provided for each category at training. So far, existing FSAD studies follow the one-model-per-category learning paradigm used for standard AD, and the inter-category commonality has not been explored. Inspired by how humans detect anomalies, i.e., comparing an image in question to normal images, we here leverage registration, an image alignment task that is inherently generalizable across categories, as the proxy task, to train a category-agnostic anomaly detection model. During testing, the anomalies are identified by comparing the registered features of the test image and its corresponding support (normal) images. As far as we know, this is the first FSAD method that trains a single generalizable model and requires no re-training or parameter fine-tuning for new categories.

**Keywords**: Anomaly Detection, Few-Shot Learning, Registration
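
The test-time comparison described in the abstract can be illustrated with a small NumPy sketch. It assumes a PaDiM-style scoring rule (a per-position Gaussian fitted on support features, scored with the Mahalanobis distance); this is an illustration under that assumption, not necessarily the exact procedure in this repository. The function names `fit_gaussian` and `anomaly_map` and the feature shapes are hypothetical.

```python
import numpy as np


def fit_gaussian(support_feats, eps=0.01):
    """Fit one Gaussian per spatial position from support (normal) features.

    support_feats: array of shape (N, C, H, W), assumed to be already registered.
    Returns the mean (C, H*W) and the inverse covariance (H*W, C, C).
    """
    n, c, h, w = support_feats.shape
    x = support_feats.reshape(n, c, h * w)
    mean = x.mean(axis=0)
    centered = x - mean
    # Covariance over the N support samples, computed independently per position.
    cov = np.einsum("nip,njp->pij", centered, centered) / max(n - 1, 1)
    cov += eps * np.eye(c)  # regularize so the matrices stay invertible
    return mean, np.linalg.inv(cov)


def anomaly_map(test_feat, mean, cov_inv):
    """Mahalanobis distance of a test feature map (C, H, W) at every position."""
    c, h, w = test_feat.shape
    diff = test_feat.reshape(c, h * w) - mean
    dist_sq = np.einsum("ip,pij,jp->p", diff, cov_inv, diff)
    return np.sqrt(np.maximum(dist_sq, 0.0)).reshape(h, w)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    support = rng.normal(size=(8, 4, 5, 5))  # 8 hypothetical support feature maps
    mean, cov_inv = fit_gaussian(support)
    test = support.mean(axis=0).copy()
    test[:, 2, 3] += 10.0  # inject a synthetic anomaly at position (2, 3)
    scores = anomaly_map(test, mean, cov_inv)
    print(np.unravel_index(scores.argmax(), scores.shape))
```

Positions whose features deviate from the support statistics get high scores; the maximum of the map can serve as an image-level score.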

## Get Started

### Environment
- python >= 3.7.11
- pytorch >= 1.11.0
- torchvision >= 0.12.0
- numpy >= 1.19.5
- scipy >= 1.7.3
- scikit-image (`skimage`) >= 0.19.2
- matplotlib >= 3.5.2
- kornia >= 0.6.5
- tqdm
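
A quick way to compare an existing environment against the minimum versions above is sketched below. It uses `importlib.metadata`, which needs Python 3.8 or newer. The pip distribution names (`torch` for pytorch, `scikit-image` for skimage) are assumptions about how the packages were installed, and `check_environment` is a hypothetical helper, not part of this repository.

```python
import re
from importlib import metadata

# Minimum versions copied from the list above; keys are assumed pip names.
MINIMUMS = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "numpy": "1.19.5",
    "scipy": "1.7.3",
    "scikit-image": "0.19.2",
    "matplotlib": "3.5.2",
    "kornia": "0.6.5",
}


def version_tuple(version):
    """Turn '1.11.0+cu113' into (1, 11, 0); local and pre-release suffixes are dropped."""
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_environment(minimums=MINIMUMS):
    """Return (package, installed_version_or_None, required_version) for every failure."""
    problems = []
    for name, required in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append((name, None, required))
            continue
        if version_tuple(installed) < version_tuple(required):
            problems.append((name, installed, required))
    return problems


if __name__ == "__main__":
    for name, installed, required in check_environment():
        print(f"{name}: found {installed or 'nothing'}, need >= {required}")
```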

### Files Preparation

1. Download the MVTec dataset [here](https://www.mvtec.com/company/research/datasets/mvtec-ad).
2. Download the support dataset for few-shot anomaly detection from [Google Drive](https://drive.google.com/file/d/1AZcc77cmDfkWA8f8cs-j-CUuFFQ7tPoK/view?usp=sharing) or [Baidu Disk](https://pan.baidu.com/s/1GZAqtscOaPliaFCiSKlViA) (extraction code: i9rx)
and extract it. If you have trouble downloading the full support set, you can instead download only the `capsule` and `grid` categories from [Baidu Disk](https://pan.baidu.com/s/1fFwAB__bV0ja38B4w3JnXQ) (extraction code: pll9) and [Baidu Disk](https://pan.baidu.com/s/1_--hXPPnlv3Tv7HHd4HRZQ) (extraction code: ns0n).
```
tar -xvf support_set.tar
```
We encourage follow-up work to use these support sets so that comparisons between methods are fair.
3. Download the pre-trained models from [Google Drive](https://drive.google.com/file/d/1guZBh40btPRmxcnY_lud88V1NoT-eWWX/view?usp=sharing) or [Baidu Disk](https://pan.baidu.com/s/1w7-6zicbZA6ysHMSpTHNhg) (extraction code: 4qyo)
and extract the checkpoint files.
```
tar -xvf save_checkpoints.tar
```
After the preparation work, the whole project should have the following structure:

```
./RegAD
├── README.md
├── train.py                  # training code
├── test.py                   # testing code
├── MVTec                     # MVTec dataset files
│   ├── bottle
│   ├── cable
│   ├── ...
│   └── zipper
├── support_set               # MVTec support dataset files
│   ├── 2
│   ├── 4
│   └── 8
├── models                    # models and backbones
│   ├── stn.py
│   └── siamese.py
├── losses                    # losses
│   └── norm_loss.py
├── datasets                  # dataset
│   └── mvtec.py
├── save_checkpoints          # model checkpoint files
└── utils                     # utils
    ├── utils.py
    └── funcs.py
```
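
Before running anything, it can help to confirm that the layout matches the tree above. The following is a minimal sketch; `missing_paths` is a hypothetical helper, and the list only covers the entries named explicitly in the tree (individual MVTec categories are not checked).

```python
from pathlib import Path

# Paths taken from the tree above, relative to the project root.
EXPECTED = [
    "train.py",
    "test.py",
    "MVTec",
    "support_set/2",
    "support_set/4",
    "support_set/8",
    "models/stn.py",
    "models/siamese.py",
    "losses/norm_loss.py",
    "datasets/mvtec.py",
    "save_checkpoints",
    "utils/utils.py",
    "utils/funcs.py",
]


def missing_paths(root, expected=EXPECTED):
    """Return the expected relative paths that do not exist under root."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_paths(".")
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Project layout looks complete.")
```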

### Quick Start

```bash
python test.py --obj <target-object> --shot <few-shot-number> --stn_mode rotation_scale
```

For example, to test on the category `bottle` with `k=2`:
```bash
python test.py --obj bottle --shot 2 --stn_mode rotation_scale
```

## Training

```bash
python train.py --obj <target-object> --shot <few-shot-number> --data_type mvtec --data_path ./MVTec/ --epochs 50 --batch_size 32 --lr 0.0001 --momentum 0.9 --inferences 10 --stn_mode rotation_scale
```

For example, to train a RegAD model on the MVTec dataset on `bottle` with `k=2`, simply run:

```bash
python train.py --obj bottle --shot 2 --data_type mvtec --data_path ./MVTec/ --epochs 50 --batch_size 32 --lr 0.0001 --momentum 0.9 --inferences 10 --stn_mode rotation_scale
```

Then you can run the evaluation using:
```bash
python test.py --obj bottle --shot 2 --stn_mode rotation_scale
```
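
To repeat the train-then-evaluate steps over several categories and shot counts, the commands can be generated from Python. This is a sketch that only reuses the flags and values shown above; the helpers `build_train_command`, `build_test_command`, and `run_all` are hypothetical and not part of this repository. With `dry_run=True` (the default) the commands are printed and nothing is executed.

```python
import subprocess


def build_test_command(obj, shot, stn_mode="rotation_scale"):
    """Command line for evaluating one category, mirroring the example above."""
    return ["python", "test.py", "--obj", obj, "--shot", str(shot),
            "--stn_mode", stn_mode]


def build_train_command(obj, shot, data_path="./MVTec/", stn_mode="rotation_scale"):
    """Command line for training, using the hyperparameters from the example above."""
    return ["python", "train.py", "--obj", obj, "--shot", str(shot),
            "--data_type", "mvtec", "--data_path", data_path,
            "--epochs", "50", "--batch_size", "32", "--lr", "0.0001",
            "--momentum", "0.9", "--inferences", "10", "--stn_mode", stn_mode]


def run_all(categories, shots, dry_run=True):
    """Train and then evaluate for every (category, shot) pair."""
    for obj in categories:
        for shot in shots:
            for cmd in (build_train_command(obj, shot), build_test_command(obj, shot)):
                print(" ".join(cmd))
                if not dry_run:
                    subprocess.run(cmd, check=True)  # stop on the first failure


if __name__ == "__main__":
    run_all(["bottle", "cable"], [2, 4, 8])
```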

## Results

Results of few-shot anomaly detection and localization with k=2:

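All values are AUC (%), K=2.

| Category | Detection (RegAD) | Detection (Implementation) | Localization (RegAD) | Localization (Implementation) |
| --- | --- | --- | --- | --- |
| bottle | 99.4 | 99.7 | 98.0 | 98.6 |
| cable | 65.1 | 69.8 | 91.7 | 94.2 |
| capsule | 67.5 | 68.6 | 97.3 | 97.6 |
| carpet | 96.5 | 96.7 | 98.9 | 98.9 |
| grid | 84.0 | 79.1 | 77.4 | 77.5 |
| hazelnut | 96.0 | 96.3 | 98.1 | 98.2 |
| leather | 99.4 | 100 | 98.0 | 99.2 |
| metal_nut | 91.4 | 94.2 | 96.9 | 98.0 |
| pill | 81.3 | 66.1 | 93.6 | 97.0 |
| screw | 52.5 | 53.9 | 94.4 | 94.1 |
| tile | 94.3 | 98.9 | 94.3 | 95.1 |
| toothbrush | 86.6 | 86.8 | 98.2 | 98.2 |
| transistor | 86.0 | 82.2 | 93.4 | 93.3 |
| wood | 99.2 | 99.8 | 93.5 | 96.5 |
| zipper | 86.3 | 90.9 | 95.1 | 98.3 |
| average | 85.7 | 85.5 | 94.6 | 95.6 |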
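
The `average` row is consistent with an unweighted mean over the 15 categories. The sketch below recomputes it from the k=2 numbers above; the dictionary `K2` and the helper `column_means` are introduced here only for illustration.

```python
# Per-category AUC (%) for k=2, copied from the results above.
# Column order: detection (RegAD), detection (Implementation),
#               localization (RegAD), localization (Implementation).
K2 = {
    "bottle": (99.4, 99.7, 98.0, 98.6),
    "cable": (65.1, 69.8, 91.7, 94.2),
    "capsule": (67.5, 68.6, 97.3, 97.6),
    "carpet": (96.5, 96.7, 98.9, 98.9),
    "grid": (84.0, 79.1, 77.4, 77.5),
    "hazelnut": (96.0, 96.3, 98.1, 98.2),
    "leather": (99.4, 100.0, 98.0, 99.2),
    "metal_nut": (91.4, 94.2, 96.9, 98.0),
    "pill": (81.3, 66.1, 93.6, 97.0),
    "screw": (52.5, 53.9, 94.4, 94.1),
    "tile": (94.3, 98.9, 94.3, 95.1),
    "toothbrush": (86.6, 86.8, 98.2, 98.2),
    "transistor": (86.0, 82.2, 93.4, 93.3),
    "wood": (99.2, 99.8, 93.5, 96.5),
    "zipper": (86.3, 90.9, 95.1, 98.3),
}


def column_means(rows):
    """Unweighted mean of each of the four columns, rounded to one decimal."""
    n = len(rows)
    return tuple(round(sum(vals[i] for vals in rows.values()) / n, 1) for i in range(4))


print(column_means(K2))  # (85.7, 85.5, 94.6, 95.6)
```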

Results of few-shot anomaly detection and localization with k=4:

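All values are AUC (%), K=4.

| Category | Detection (RegAD) | Detection (Implementation) | Localization (RegAD) | Localization (Implementation) |
| --- | --- | --- | --- | --- |
| bottle | 99.4 | 99.3 | 98.4 | 98.5 |
| cable | 76.1 | 82.9 | 92.7 | 95.5 |
| capsule | 72.4 | 77.3 | 97.6 | 98.3 |
| carpet | 97.9 | 97.9 | 98.9 | 98.9 |
| grid | 91.2 | 87 | 85.7 | 85.7 |
| hazelnut | 95.8 | 95.9 | 98.0 | 98.4 |
| leather | 100 | 99.9 | 99.1 | 99 |
| metal_nut | 94.6 | 94.3 | 97.8 | 96.5 |
| pill | 80.8 | 74.0 | 97.4 | 97.4 |
| screw | 56.6 | 59.3 | 95.0 | 96.0 |
| tile | 95.5 | 98.2 | 94.9 | 92.6 |
| toothbrush | 90.9 | 91.1 | 98.5 | 98.5 |
| transistor | 85.2 | 85.5 | 93.8 | 93.5 |
| wood | 98.6 | 98.9 | 94.7 | 96.3 |
| zipper | 88.5 | 95.8 | 94.0 | 98.6 |
| average | 88.2 | 89.2 | 95.8 | 96.2 |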

Results of few-shot anomaly detection and localization with k=8:

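All values are AUC (%), K=8.

| Category | Detection (RegAD) | Detection (Implementation) | Localization (RegAD) | Localization (Implementation) |
| --- | --- | --- | --- | --- |
| bottle | 99.8 | 99.8 | 97.5 | 98.5 |
| cable | 80.6 | 81.5 | 94.9 | 95.8 |
| capsule | 76.3 | 78.4 | 98.2 | 98.4 |
| carpet | 98.5 | 98.6 | 98.9 | 98.9 |
| grid | 91.5 | 91.5 | 88.7 | 88.7 |
| hazelnut | 96.5 | 97.3 | 98.5 | 98.5 |
| leather | 100 | 100 | 98.9 | 99.3 |
| metal_nut | 98.3 | 98.6 | 96.9 | 98.3 |
| pill | 80.6 | 77.8 | 97.8 | 97.7 |
| screw | 63.4 | 65.8 | 97.1 | 97.3 |
| tile | 97.4 | 99.6 | 95.2 | 96.1 |
| toothbrush | 98.5 | 96.6 | 98.7 | 99.0 |
| transistor | 93.4 | 90.3 | 96.8 | 95.9 |
| wood | 99.4 | 99.5 | 94.6 | 96.5 |
| zipper | 94.0 | 93.4 | 97.4 | 97.4 |
| average | 91.2 | 91.2 | 96.8 | 97.1 |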

## Visualization

## Acknowledgement
We borrow some code from [SimSiam](https://github.com/facebookresearch/simsiam), [STN](https://github.com/YotYot/CalibrationNet/blob/2446a3bcb7ff4aa1e492adcde62a4b10a33635b4/models/configurable_stn_no_stereo.py), and [PaDiM](https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master).

## Contact

If you have any problem with this code, please feel free to contact **[email protected]**.