# Enhance Image Classification Via Inter-Class Image Mixup With Diffusion Model
## Introduction
Currently, a common method to enhance image classification involves expanding the training set with synthetic data generated by text-to-image (T2I) models.
Here, we propose an inter-class data augmentation method, Diff-Mix.
Diff-Mix expands the dataset by conducting image translation in an inter-class manner, significantly improving the diversity of synthetic data.
We observe an improved trade-off between faithfulness and diversity with Diff-Mix, resulting in significant performance gains across various image classification settings, including few-shot, conventional, and long-tail classification, particularly on domain-specific datasets.

## Datasets
For convenience, the well-structured datasets hosted on Hugging Face can be used directly. The fine-grained datasets `CUB` and `Aircraft` used in our experiments can be downloaded from [Multimodal-Fatima/CUB_train](https://huggingface.co/datasets/Multimodal-Fatima/CUB_train) and [Multimodal-Fatima/FGVC_Aircraft_train](https://huggingface.co/datasets/Multimodal-Fatima/FGVC_Aircraft_train), respectively. If you encounter network connection problems during training, please pre-download the data from the website and specify the saved local path `HUG_LOCAL_IMAGE_TRAIN_DIR` in `semantic_aug/datasets/cub.py`.

## Code Description
### Fine-tuning
We fine-tune both the textual tokens and the U-Net (via LoRA, following [diffusers](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)) of the pre-trained Stable Diffusion model to expedite the fine-tuning process. To simplify usage, the concrete fine-tuning command is wrapped in the script `scripts/finetune.sh`. Distributed training is performed with the `accelerate` tool, and the GPUs should be specified via the environment variable `CUDA_VISIBLE_DEVICES`. The simplified command for fine-tuning on the full training set of `CUB` for a total of `35000` steps is:
```bash
source scripts/finetune.sh
bash finetune 'cub' 'ti_db' -1 35000
```
To fine-tune in a 5-shot setting, modify the shell command to:
```bash
source scripts/finetune.sh
bash finetune 'cub' 'ti_db' 5 35000
```
The fine-tuned checkpoints will be saved under `outputs/finetune_model/finetune_ti_db{_5shot}/cub/`. After that, please manually add the meta-information of the checkpoints to `config/finetuned_ckpts.yaml` in the following format:
```yaml
cub:
  ti_db_latest:
    model_path: "runwayml/stable-diffusion-v1-5"
    lora_path: "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/checkpoint-35000/pytorch_model.bin"
    embed_path: "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/learned_embeds-steps-35000.bin"
```
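Once parsed (e.g. with PyYAML), this file becomes a nested dict keyed by dataset and checkpoint name. A minimal lookup sketch, using the paths from the example entry above (the helper function itself is hypothetical, not part of the repo):

```python
# Parsed form of config/finetuned_ckpts.yaml (values copied from the example entry).
FINETUNED_CKPTS = {
    "cub": {
        "ti_db_latest": {
            "model_path": "runwayml/stable-diffusion-v1-5",
            "lora_path": "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/checkpoint-35000/pytorch_model.bin",
            "embed_path": "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/learned_embeds-steps-35000.bin",
        }
    }
}

def lookup_ckpt(dataset: str, key: str) -> dict:
    """Resolve a (dataset, key) pair to its checkpoint meta-information."""
    try:
        return FINETUNED_CKPTS[dataset][key]
    except KeyError:
        raise KeyError(f"no checkpoint registered for ({dataset!r}, {key!r})")
```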
This structure allows you to locate the checkpoint paths simply by using the key pair (`cub`, `ti_db_latest`).

### Construct synthetic data
Similarly, we wrap the command details in the script `scripts/sample.sh`. To expedite inference, we use the `multiprocessing` tool to launch multiple inference processes. The desired processes should be specified via the environment variable `GPU_IDS`, where each item in the list denotes a process running on the GPU with that index. The simplified command for sampling a $5\times$ synthetic subset in an inter-class translation manner (Diff-Mix) with strength $s=0.7$ is:
```bash
source scripts/sample.sh
export GPU_IDS=(0 0 0 1 1 1)
bash sample 'cub' 'ti_db_latest' 'diff-mix' 0.7
```
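The `GPU_IDS` list above launches one worker per entry (so `(0 0 0 1 1 1)` runs three processes on GPU 0 and three on GPU 1). A minimal sketch of how such per-GPU sharding might look, illustrative only and not the repo's actual implementation:

```python
import multiprocessing as mp
import os

def split_shards(gpu_ids, work_items):
    """Round-robin split: worker i handles items i, i+n, i+2n, ..."""
    n = len(gpu_ids)
    return [work_items[i::n] for i in range(n)]

def _worker(gpu_id, shard, out_q):
    # Pin this process to a single GPU before any CUDA initialization.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # The real worker would run the sampling loop over `shard` here.
    out_q.put((gpu_id, len(shard)))

def launch(gpu_ids, work_items):
    out_q = mp.Queue()
    procs = [mp.Process(target=_worker, args=(g, s, out_q))
             for g, s in zip(gpu_ids, split_shards(gpu_ids, work_items))]
    for p in procs:
        p.start()
    results = [out_q.get() for _ in procs]
    for p in procs:
        p.join()
    return results
```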
One can also construct the synthetic subset with other expansion strategies by replacing `diff-mix` with `diff-aug` (Diff-Aug, fine-tuned intra-class translation), `real-mix` (Real-Mix, pre-trained inter-class translation), or `real-guidance` (Real-Aug, pre-trained intra-class translation). To sample in a 5-shot setting, modify the shell command to:
```bash
source scripts/sample.sh
export GPU_IDS=(0 0 0 1 1 1)
bash sample_fewshot 5 'cub' '5shot_ti_db_latest' 'diff-mix' 0.7
```
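The four expansion strategies differ along two axes: which diffusion model generates the images, and whether the translation crosses class boundaries. A small lookup table (illustrative, based on the descriptions above) makes the combinations explicit:

```python
# Strategy name -> (diffusion model used, translation direction),
# as described in the text above.
STRATEGIES = {
    "diff-mix":      {"model": "fine-tuned",  "translation": "inter-class"},
    "diff-aug":      {"model": "fine-tuned",  "translation": "intra-class"},
    "real-mix":      {"model": "pre-trained", "translation": "inter-class"},
    "real-guidance": {"model": "pre-trained", "translation": "intra-class"},
}

def is_mixing(strategy: str) -> bool:
    """True for strategies that translate images across class boundaries."""
    return STRATEGIES[strategy]["translation"] == "inter-class"
```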
The sampled subset will be cached at `outputs/aug_samples{_5shot}/cub`. After that, please manually add the meta-information of the subset into `synthetic_datasets.yaml` in the following format:
```yaml
cub:
  diffmix_0.7: 'outputs/aug_samples/cub/diff-mix-Multi7-ti_db35000-Strength0.7'
  5shot_diffmix_0.7: 'outputs/aug_samples_5shot/cub/diff-mix-Multi7-ti_db35000-Strength0.7'
```
This allows you to locate the synthetic paths simply by using the key pair (`cub`, `diffmix_0.7`) in case there are multiple subsets.
### Downstream classification
After completing the sampling process, you can integrate the synthetic data into downstream classification and launch training with the following commands:
```bash
source scripts/classification.sh
# main_cls {dataset_name} {gpu} {seed} {model} {resolution} {nepoch} {syndata_key} {gamma} {synthetic_prob}
main_cls 'cub' '0' 2020 'resnet50' '224' 120 'diffmix_0.7' 0.5 0.1
```
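The roles of the last two arguments can be sketched as follows. Note this is assumed semantics inferred from the parameter names `synthetic_prob` and `gamma`, not the repo's exact implementation: with probability `synthetic_prob` a synthetic (Diff-Mix) sample is drawn instead of a real one, and its soft label mixes the source and target classes with weight `gamma`:

```python
import random

def sample_batch_item(real_ds, syn_ds, synthetic_prob=0.1, gamma=0.5):
    # Hypothetical sketch: real_ds holds (image, label) pairs; syn_ds holds
    # (image, source_label, target_label) triples from inter-class translation.
    if random.random() < synthetic_prob:
        img, src_label, tgt_label = random.choice(syn_ds)
        # Soft label: weight `gamma` on the translation target class.
        soft_label = {src_label: 1.0 - gamma, tgt_label: gamma}
        return img, soft_label
    img, label = random.choice(real_ds)
    return img, {label: 1.0}
```

With `gamma=0.5` a synthetic sample contributes equally to its source and target classes; `synthetic_prob=0.1` keeps 90% of each batch real, matching the command above.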
## Acknowledgements
This project is built upon the repository [Da-fusion](https://github.com/brandontrabucco/da-fusion) and [diffusers](https://github.com/huggingface/diffusers). Special thanks to the contributors.
## Requirements