https://github.com/yueyang1996/knobo
NeurIPS 2024 (spotlight): A Textbook Remedy for Domain Shifts: Knowledge Priors for Medical Image Analysis
- Host: GitHub
- URL: https://github.com/yueyang1996/knobo
- Owner: YueYANG1996
- License: mit
- Created: 2024-05-23T03:43:37.000Z (about 2 years ago)
- Default Branch: main
- Last Pushed: 2024-10-15T18:54:20.000Z (over 1 year ago)
- Topics: confounding, domain-shift, interpretable-machine-learning, medical, retrieval-augmented-generation
- Language: Python
- Homepage: https://yueyang1996.github.io/knobo/
- Size: 1000 KB
- Stars: 26
- Watchers: 1
- Forks: 3
- Open Issues: 0
# A Textbook Remedy for Domain Shifts: Knowledge Priors for Medical Image Analysis
[Paper](https://arxiv.org/abs/2405.14839) | [Project Page](https://yueyang1996.github.io/knobo/)
## Table of Contents
1. [CLIP Models](#clip-models)
2. [Installation](#installation)
3. [Quick Start](#quick-start)
4. [Directories](#directories)
5. [Extract Features](#extract-features)
6. [Generate Bottlenecks from Medical Documents](#generate-bottlenecks-from-medical-documents)
7. [Train Grounding Functions](#train-grounding-functions)
8. [Baselines](#baselines)
## CLIP Models
We release the two CLIP models we trained for X-ray and skin lesion images on Hugging Face:
* **WhyXrayCLIP** 🩻 : https://huggingface.co/yyupenn/whyxrayclip
* **WhyLesionCLIP** 👍🏽 : https://huggingface.co/yyupenn/whylesionclip
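As a minimal sketch, assuming both checkpoints are packaged for the `open_clip` library and load through its `hf-hub:` prefix (check each model card for the officially supported loading code), the snippet below loads a released model and scores images against text prompts by cosine similarity. The helpers `load_why_clip`, `embed`, and `cosine_scores` are hypothetical and not part of this repo.

```python
import numpy as np


def load_why_clip(repo_id="yyupenn/whyxrayclip"):
    """Load a released checkpoint (assumption: it is packaged for open_clip)."""
    import open_clip  # pip install open_clip_torch

    model, _, preprocess = open_clip.create_model_and_transforms(f"hf-hub:{repo_id}")
    tokenizer = open_clip.get_tokenizer(f"hf-hub:{repo_id}")
    model.eval()
    return model, preprocess, tokenizer


def embed(model, preprocess, tokenizer, images, prompts):
    """Encode PIL images and text prompts into NumPy feature arrays."""
    import torch

    with torch.no_grad():
        pixels = torch.stack([preprocess(img) for img in images])
        image_feats = model.encode_image(pixels).float().cpu().numpy()
        text_feats = model.encode_text(tokenizer(prompts)).float().cpu().numpy()
    return image_feats, text_feats


def cosine_scores(image_feats, text_feats):
    """(num_images, num_prompts) cosine similarities between the two sets."""
    img = image_feats / np.linalg.norm(image_feats, axis=-1, keepdims=True)
    txt = text_feats / np.linalg.norm(text_feats, axis=-1, keepdims=True)
    return img @ txt.T
```

For WhyLesionCLIP, pass `repo_id="yyupenn/whylesionclip"`. The heavy imports (`open_clip`, `torch`) live inside the functions so the similarity helper can be used on precomputed features without them.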
## Installation
Clone the repo, then install the required dependencies and download the data by running `setup.sh`:
```bash
git clone https://github.com/YueYANG1996/KnoBo.git
cd KnoBo
sh setup.sh
```
## Quick Start
To get the results of KnoBo on X-ray datasets, you can run the following command:
```bash
python modules/cbm.py \
--mode binary \
--bottleneck PubMed \
--number_of_features 150 \
--add_prior True \
--modality xray \
--model_name whyxrayclip
```
The output will be saved to `./data/results/`. You can change the `--modality` to `skin` and `--model_name` to `whylesionclip` to get the results on Skin Lesion datasets.
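To run the Quick Start configuration for both modalities in one go, a small Python driver like the one below can build the two commands. It only uses the flags and values shown above; the function names are hypothetical, and it assumes it is launched from the repository root.

```python
import shlex
import subprocess

# Modality -> CLIP model pairing described in the Quick Start.
MODEL_FOR_MODALITY = {"xray": "whyxrayclip", "skin": "whylesionclip"}


def knobo_command(modality, bottleneck="PubMed", number_of_features=150):
    """Return the argv list for one KnoBo run with `--add_prior True`."""
    return [
        "python", "modules/cbm.py",
        "--mode", "binary",
        "--bottleneck", bottleneck,
        "--number_of_features", str(number_of_features),
        "--add_prior", "True",
        "--modality", modality,
        "--model_name", MODEL_FOR_MODALITY[modality],
    ]


def run_all(dry_run=True):
    """Print (and, unless dry_run, execute) the command for every modality."""
    for modality in MODEL_FOR_MODALITY:
        cmd = knobo_command(modality)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
```

Calling `run_all()` only prints the commands; `run_all(dry_run=False)` launches them one after another and stops on the first failure. Passing an argv list to `subprocess.run` avoids shell-quoting issues.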
## Directories
* `data/`: Contains the data for all experiments.
- `data/bottlenecks/`: Contains the concept bottlenecks created from medical documents.
- `data/datasets/`: Contains the splits for all datasets. You may need to download the images for each dataset from its original source; see [DATASETS.md](DATASETS.md) for details.
- `data/features/`: Contains the features extracted from different models.
- `data/grounding_functions/`: Contains the grounding functions for each concept in the bottleneck.
- `data/results/`: Contains the results of all experiments.
* `modules/`: Contains the scripts for all experiments.
- [`modules/cbm.py`](modules/cbm.py): Contains the script for running the linear models, including KnoBo, linear probing, and PCBM.
- [`modules/extract_features.py`](modules/extract_features.py): Contains the script for extracting image features using different models.
- [`modules/train_grounding.py`](modules/train_grounding.py): Contains the script for training the grounding functions for each concept in the bottleneck.
- [`modules/end2end.py`](modules/end2end.py): Contains the script for training the end-to-end models, including ViT and DenseNet.
- [`modules/LSL.py`](modules/LSL.py): Contains the script for fine-tuning CLIP with knowledge (Language-shaped Learning).
- [`modules/models.py`](modules/models.py): Contains the models used in the experiments.
- [`modules/utils.py`](modules/utils.py): Contains the utility functions.
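After running `setup.sh`, a quick sanity check of the `data/` layout can save a failed run later. The sketch below only checks for the sub-directories named in the list above; the helper name is hypothetical, and whether `setup.sh` creates every one of them (for example, `results/` may only appear after a first run) is an assumption worth verifying.

```python
from pathlib import Path

# data/ sub-directories named in the Directories list.
EXPECTED_DATA_DIRS = ("bottlenecks", "datasets", "features", "grounding_functions", "results")


def missing_data_dirs(repo_root="."):
    """Return the expected data/ sub-directories that are not present yet."""
    data_dir = Path(repo_root) / "data"
    return [name for name in EXPECTED_DATA_DIRS if not (data_dir / name).is_dir()]
```

Running `print(missing_data_dirs())` from the repository root lists anything still missing, in the order given above.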
## Extract Features
After running [`setup.sh`](setup.sh), the `data/features/` directory should already contain the features extracted with our two CLIP models. To extract features with other models, run:
```bash
python modules/extract_features.py \
--dataset_name <dataset_name> \
--model_name <model_name> \
--image_dir <image_dir>
```
The supported models are listed [here](https://github.com/YueYANG1996/KnoBo/blob/e3e3171b74b6c8f42046676aa6c6ae21a034deba/modules/extract_features.py#L141). We provide a bash script [`extract_features.sh`](extract_features.sh) to extract features for all datasets using the two CLIP models we trained.
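For readers adding a new backbone, the general shape of feature extraction is sketched below: encode images in batches, L2-normalize each feature vector, and save the result as one array. This is not the code in `modules/extract_features.py`; `encode_batch` is a hypothetical callable standing in for a model's image encoder, and both the normalization and the `.npy` output format are assumptions rather than the repo's documented behavior.

```python
from pathlib import Path

import numpy as np


def extract_features(image_paths, encode_batch, out_file, batch_size=64):
    """Encode images in batches and save L2-normalized features as a .npy array.

    encode_batch: hypothetical callable mapping a list of paths to a (B, D) array.
    """
    chunks = []
    for start in range(0, len(image_paths), batch_size):
        batch = image_paths[start:start + batch_size]
        feats = np.asarray(encode_batch(batch), dtype=np.float32)
        norms = np.linalg.norm(feats, axis=1, keepdims=True)
        chunks.append(feats / np.maximum(norms, 1e-12))  # unit-length rows, no divide-by-zero
    features = np.concatenate(chunks, axis=0)
    out_path = Path(out_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    np.save(out_path, features)
    return features
```

Batching keeps memory bounded for large datasets, and normalizing up front means a downstream linear model sees features on a comparable scale regardless of the backbone.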
## Generate Bottlenecks from Medical Documents
We build the retrieval-based concept bottleneck generation pipeline based on [MedRAG](https://arxiv.org/pdf/2402.13178). You need to first clone our [forked version](https://github.com/YueYANG1996/MedRAG/tree/main) and set up the environment by running the following commands:
```bash
git clone https://github.com/YueYANG1996/MedRAG.git
cd MedRAG
sh setup.sh
```
Setup may take a while because it downloads 5M PubMed documents (29.5 GB). Once the environment is ready, you can test the RAG system by running [`test.py`](https://github.com/YueYANG1996/MedRAG/blob/main/test.py).
To generate the concept bottleneck from medical documents, you can run the following command:
```bash
python concept_generation.py \
--modality <modality> \
--corpus_name <corpus_name> \
--number_of_concepts <number_of_concepts> \
--openai_key <openai_key>
```
For `--corpus_name`, you can choose from `PubMed_all` (our version of PubMed with all paragraphs), `PubMed` (MedRAG's original version of PubMed, which only has abstracts), `Textbooks`, `StatPearls`, and `Wikipedia`. The generated bottleneck will be saved to `./data/bottlenecks/__.txt`.
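The retrieve-then-generate idea behind concept generation can be sketched as follows: rank corpus snippets by embedding similarity to a query, then hand the top hits to an LLM as context. This is a simplified illustration, not MedRAG's API or the prompt used by `concept_generation.py`; the function names, the prompt wording, and the use of plain cosine similarity over precomputed embeddings are all assumptions.

```python
import numpy as np


def top_k_documents(query_vec, doc_vecs, docs, k=3):
    """Return the k documents whose embeddings are most cosine-similar to the query."""
    q = query_vec / np.linalg.norm(query_vec)
    d = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    scores = d @ q
    order = np.argsort(-scores)[:k]  # highest similarity first
    return [docs[i] for i in order]


def concept_prompt(question, snippets):
    """Hypothetical prompt asking an LLM to propose yes/no visual concepts."""
    context = "\n\n".join(f"[{i + 1}] {s}" for i, s in enumerate(snippets))
    return (
        f"Documents:\n{context}\n\n"
        f"Question: {question}\n"
        "List visual attributes a clinician could check in the image, "
        "one yes/no question per line."
    )
```

Numbering the snippets lets the LLM's output be traced back to the retrieved source, which is part of what makes a document-grounded bottleneck inspectable.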
**Annotate concepts:** You can annotate clinical reports for each concept in the bottleneck by running the following command:
```bash
python annotate_question.py \
--annotator <annotator> \
--modality <modality> \
--bottleneck <bottleneck> \
--number_of_reports <number_of_reports> \
--openai_key <openai_key>
```
The default LLM for annotation is [Flan-T5-XXL](https://huggingface.co/google/flan-t5-xxl). You can switch to GPT-4 by setting `--annotator gpt4` (warning: this can be expensive). By default, 1,000 reports are annotated. The annotated reports will be saved to `./data/concept_annotation_/annotations_/`.
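Downstream, each annotated report becomes a row of binary concept labels. A minimal sketch of that conversion, assuming the annotator returns free-text answers keyed by concept (a hypothetical format; the files written by `annotate_question.py` may be structured differently):

```python
import numpy as np


def answers_to_labels(answers, concepts):
    """Convert free-text LLM answers into a (reports x concepts) 0/1 matrix.

    answers: list of dicts, one per report, mapping concept -> raw answer text
    (hypothetical format). Missing answers or answers not starting with "yes" become 0.
    """
    labels = np.zeros((len(answers), len(concepts)), dtype=np.int8)
    for r, per_report in enumerate(answers):
        for c, concept in enumerate(concepts):
            text = per_report.get(concept, "").strip().lower()
            labels[r, c] = int(text.startswith("yes"))
    return labels
```

Treating missing or non-"yes" answers as negatives is a conservative choice; an alternative is to mask uncertain answers out of the grounding-function loss.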
## Train Grounding Functions
To train the grounding functions for each concept in the bottleneck, you can run the following command:
```bash
python modules/train_grounding.py \
--modality <modality> \
--bottleneck <bottleneck>
```
Each grounding function is a binary classifier that predicts whether the concept is present in the image. The output will be saved to `./data/grounding_functions///`.
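Since each grounding function is a binary classifier over image features, one way to picture training is one logistic regression per concept. The NumPy sketch below is an illustrative assumption, not the implementation in `modules/train_grounding.py`: it fits all concepts at once, and because the loss separates by column, this is equivalent to training each concept's classifier independently. The function names and hyperparameters are hypothetical.

```python
import numpy as np


def sigmoid(z):
    # Clip logits so np.exp cannot overflow.
    return 1.0 / (1.0 + np.exp(-np.clip(z, -30.0, 30.0)))


def train_grounding(features, labels, lr=0.5, epochs=500, l2=1e-3):
    """Fit one L2-regularized logistic regression per concept by gradient descent.

    features: (N, D) image features; labels: (N, C) 0/1 concept annotations.
    Returns weights (D, C) and biases (C,).
    """
    n, d = features.shape
    num_concepts = labels.shape[1]
    w = np.zeros((d, num_concepts))
    b = np.zeros(num_concepts)
    for _ in range(epochs):
        probs = sigmoid(features @ w + b)  # (N, C) predicted concept probabilities
        err = probs - labels               # d(BCE)/d(logit) per sample and concept
        w -= lr * (features.T @ err / n + l2 * w)
        b -= lr * err.mean(axis=0)
    return w, b


def concept_scores(features, w, b):
    """Bottleneck representation: probability that each concept is present."""
    return sigmoid(features @ w + b)
```

Stacking these probabilities gives an (N, C) concept-score matrix, which is the kind of input a linear model on top of the bottleneck would consume.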
## Baselines
* **Linear Probing**: `python modules/cbm.py --mode linear_probe --modality <modality> --model_name <model_name>`.
* **PCBM-h**: `python modules/cbm.py --mode pcbm --bottleneck PubMed --number_of_features 150 --modality <modality> --model_name <model_name>`.
* **End-to-End**: `python modules/end2end.py --modality <modality> --model_name <model_name>`.
* **LSL**: You need to first fine-tune the CLIP model with knowledge using the following command:
```bash
python modules/LSL.py \
--modality <modality> \
--clip_model_name <clip_model_name> \
--bottleneck <bottleneck> \
--image_dir <image_dir>
```
Then, extract features with the fine-tuned CLIP model and obtain the final results in the same way as linear probing: `python modules/cbm.py --mode linear_probe --modality <modality> --model_name <model_name>`. The models we fine-tuned on PubMed are provided in the `data/model_weights/` directory.
## Citation
Please cite our paper if you find our work useful!
```bibtex
@article{yang2024textbook,
title={A Textbook Remedy for Domain Shifts: Knowledge Priors for Medical Image Analysis},
author={Yue Yang and Mona Gandhi and Yufei Wang and Yifan Wu and Michael S. Yao and Chris Callison-Burch and James C. Gee and Mark Yatskar},
journal={arXiv preprint arXiv:2405.14839},
year={2024}
}
```