{"id":21341602,"url":"https://github.com/marksgraham/transformer-ood","last_synced_at":"2025-07-12T15:30:29.790Z","repository":{"id":166625681,"uuid":"600225917","full_name":"marksgraham/transformer-ood","owner":"marksgraham","description":"Official PyTorch code for \"Transformer-based out-of-distribution detection for clinically safe segmentation\"","archived":false,"fork":false,"pushed_at":"2023-10-06T17:40:27.000Z","size":55,"stargazers_count":4,"open_issues_count":0,"forks_count":0,"subscribers_count":2,"default_branch":"main","last_synced_at":"2024-03-18T01:01:14.543Z","etag":null,"topics":["out-of-distribution-detection","pytorch","transformers","vq-vae"],"latest_commit_sha":null,"homepage":"https://proceedings.mlr.press/v172/graham22a.html","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"other","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/marksgraham.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null}},"created_at":"2023-02-10T21:43:56.000Z","updated_at":"2024-02-23T07:15:13.000Z","dependencies_parsed_at":null,"dependency_job_id":"c82ed7f0-3923-4ec0-b109-e885cf57db41","html_url":"https://github.com/marksgraham/transformer-ood","commit_stats":null,"previous_names":["marksgraham/transformer-ood"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/marksgraham%2Ftransformer-ood","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/marksgraham%2Ftransformer-ood/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/marksgraham%2Ftransformer-ood/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/marksgraham%2Ftransformer-ood/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/marksgraham","download_url":"https://codeload.github.com/marksgraham/transformer-ood/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":225824739,"owners_count":17529906,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["out-of-distribution-detection","pytorch","transformers","vq-vae"],"created_at":"2024-11-22T00:57:37.217Z","updated_at":"2024-11-22T00:57:38.907Z","avatar_url":"https://github.com/marksgraham.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Out-of-distribution detection with Transformers\n\nThis repo shows how a VQ-GAN + Transformer can be trained to perform unsupervised OOD detection on 3D medical data. It uses the freely available [Medical Decathlon dataset](http://medicaldecathlon.com/) for its experiments.\n\nThe method is fully described in the MIDL 2022 paper [Transformer-based out-of-distribution detection for clinically safe segmentation](https://proceedings.mlr.press/v172/graham22a).\n\n### Set-up\nCreate a fresh environment (this codebase was developed and tested with Python 3.8) and then install the required packages:\n\n```bash\npip install -r requirements.txt\n```\n\nAlternatively, you can build the Docker image:\n```bash\ncd docker/\nbash create_docker_image.sh\n```\n\n### Download and prepare the data\nDownload all classes of the Medical Decathlon dataset. 
NB: this will take a while.\n\n```bash\nexport DATA_ROOT=/desired/path/to/data\npython src/data/get_decathlon_datasets.py --data_root=${DATA_ROOT}\n```\n\n### Train the VQ-GAN\nWe treat the BraTS data (`Task01_BrainTumour`) as the in-distribution class and train a VQ-GAN on it.\n\nFirst, set up your output directory:\n```bash\nexport OUTPUT_DIR=/desired/path/to/output\n```\n\nThen train the VQ-GAN:\n```bash\npython train_vqvae.py \\\n--output_dir=${OUTPUT_DIR} \\\n--model_name=vqgan_decathlon \\\n--training_dir=${DATA_ROOT}/data_splits/Task01_BrainTumour_train.csv \\\n--validation_dir=${DATA_ROOT}/data_splits/Task01_BrainTumour_val.csv \\\n--n_epochs=300 \\\n--batch_size=8 \\\n--eval_freq=10 \\\n--cache_data=1 \\\n--vqvae_downsample_parameters=[[2,4,1,1],[2,4,1,1],[2,4,1,1],[2,4,1,1]] \\\n--vqvae_upsample_parameters=[[2,4,1,1,0],[2,4,1,1,0],[2,4,1,1,0],[2,4,1,1,0]] \\\n--vqvae_num_channels=[256,256,256,256] \\\n--vqvae_num_res_channels=[256,256,256,256] \\\n--vqvae_embedding_dim=128 \\\n--vqvae_num_embeddings=2048 \\\n--spatial_dimension=3 \\\n--image_roi=[160,160,128] \\\n--image_size=128 \\\n--num_workers=4\n```\n\nThe code is DDP-compatible. 
To train on N GPUs, launch the same script with `torchrun`, replacing `--args` with the arguments shown above:\n\n`torchrun --nproc_per_node=N --nnodes=1 --node_rank=0 train_vqvae.py --args`\n\n### Train the transformer\n```bash\npython train_transformer.py \\\n--output_dir=${OUTPUT_DIR} \\\n--model_name=transformer_decathlon \\\n--training_dir=${DATA_ROOT}/data_splits/Task01_BrainTumour_train.csv \\\n--validation_dir=${DATA_ROOT}/data_splits/Task01_BrainTumour_val.csv \\\n--n_epochs=100 \\\n--batch_size=4 \\\n--eval_freq=10 \\\n--cache_data=1 \\\n--image_roi=[160,160,128] \\\n--image_size=128 \\\n--num_workers=4 \\\n--vqvae_checkpoint=${OUTPUT_DIR}/vqgan_decathlon/checkpoint.pth \\\n--transformer_type=transformer\n```\n\n### Evaluate\nCompute likelihoods on the test set of the BraTS data and on the test sets of the other Medical Decathlon classes passed to `--ood_dir` (every remaining task except Task03, the liver data):\n```bash\npython perform_ood.py \\\n--output_dir=${OUTPUT_DIR} \\\n--model_name=transformer_decathlon \\\n--training_dir=${DATA_ROOT}/data_splits/Task01_BrainTumour_train.csv \\\n--validation_dir=${DATA_ROOT}/data_splits/Task01_BrainTumour_val.csv \\\n--ood_dir=${DATA_ROOT}/data_splits/Task01_BrainTumour_test.csv,${DATA_ROOT}/data_splits/Task02_Heart_test.csv,${DATA_ROOT}/data_splits/Task04_Hippocampus_test.csv,${DATA_ROOT}/data_splits/Task05_Prostate_test.csv,${DATA_ROOT}/data_splits/Task06_Lung_test.csv,${DATA_ROOT}/data_splits/Task07_Pancreas_test.csv,${DATA_ROOT}/data_splits/Task08_HepaticVessel_test.csv,${DATA_ROOT}/data_splits/Task09_Spleen_test.csv,${DATA_ROOT}/data_splits/Task10_Colon_test.csv \\\n--cache_data=0 \\\n--batch_size=2 \\\n--num_workers=0 \\\n--vqvae_checkpoint=${OUTPUT_DIR}/vqgan_decathlon/checkpoint.pth \\\n--transformer_checkpoint=${OUTPUT_DIR}/transformer_decathlon/checkpoint.pth \\\n--transformer_type=transformer \\\n--image_roi=[160,160,128] \\\n--image_size=128\n```\n\nYou can print the results with:\n```bash\npython print_ood_scores.py --results_file=${OUTPUT_DIR}/transformer_decathlon/results.csv\n```\n\n### Acknowledgements\nBuilt on top of [MONAI 
Generative](https://github.com/Project-MONAI/GenerativeModels) and [MONAI](https://github.com/Project-MONAI/MONAI).\n\n### Citation\nIf you use this codebase, please cite the following paper:\n```bibtex\n@inproceedings{graham2022transformer,\n  title={Transformer-based out-of-distribution detection for clinically safe segmentation},\n  author={Graham, Mark S and Tudosiu, Petru-Daniel and Wright, Paul and Pinaya, Walter Hugo Lopez and Jean-Marie, U and Mah, Yee H and Teo, James T and Jager, Rolf and Werring, David and Nachev, Parashkev and others},\n  booktitle={International Conference on Medical Imaging with Deep Learning},\n  pages={457--476},\n  year={2022},\n  organization={PMLR}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmarksgraham%2Ftransformer-ood","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmarksgraham%2Ftransformer-ood","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmarksgraham%2Ftransformer-ood/lists"}