https://github.com/kiselyovd/brain-mri-segmentation
Binary brain-tumor MRI segmentation (LGG) — SegFormer-B2 + U-Net baseline, FastAPI serving, HF Hub model
https://github.com/kiselyovd/brain-mri-segmentation
brain-mri deep-learning fastapi medical-imaging pytorch pytorch-lightning segformer segmentation semantic-segmentation unet
Last synced: 3 months ago
JSON representation
Binary brain-tumor MRI segmentation (LGG) — SegFormer-B2 + U-Net baseline, FastAPI serving, HF Hub model
- Host: GitHub
- URL: https://github.com/kiselyovd/brain-mri-segmentation
- Owner: kiselyovd
- License: mit
- Created: 2023-07-12T09:00:32.000Z (almost 3 years ago)
- Default Branch: main
- Last Pushed: 2026-04-14T18:02:44.000Z (3 months ago)
- Last Synced: 2026-04-14T20:07:45.367Z (3 months ago)
- Topics: brain-mri, deep-learning, fastapi, medical-imaging, pytorch, pytorch-lightning, segformer, segmentation, semantic-segmentation, unet
- Language: Python
- Homepage: https://kiselyovd.github.io/brain-mri-segmentation/
- Size: 1.87 MB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# brain-mri-segmentation
[](https://github.com/kiselyovd/brain-mri-segmentation/actions/workflows/ci.yml)
[](https://kiselyovd.github.io/brain-mri-segmentation/)
[](LICENSE)
[](https://www.python.org/)
[](https://huggingface.co/kiselyovd/brain-mri-segmentation)
Binary brain-tumor segmentation on MRI slices — fine-tuned **SegFormer-B2** as the main model and a hand-rolled **U-Net** as a reproducible baseline, both trained on the Mateusz Buda LGG (TCGA) dataset with a strict patient-level split to prevent data leakage.
**Russian:** [README.ru.md](README.ru.md) · **Docs:** [kiselyovd.github.io/brain-mri-segmentation](https://kiselyovd.github.io/brain-mri-segmentation/) · **Model:** [kiselyovd/brain-mri-segmentation](https://huggingface.co/kiselyovd/brain-mri-segmentation)
## Dataset
Mateusz Buda's [LGG MRI Segmentation](https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation) on Kaggle — 110 patients, 3,929 paired FLAIR slices with binary tumor masks from The Cancer Genome Atlas (TCGA). `src/brain_mri_segmentation/data/prepare.py` performs a **patient-level 80/10/10 split** (88/11/11 patients), so no patient appears in more than one partition.
Resulting slice counts: **3,133 train / 409 val / 387 test**.
## Results
Test-set metrics after full training (fill in with real numbers from `reports/metrics.json`):
| Model | Dice | IoU | Pixel Accuracy |
|---|---|---|---|
| **SegFormer-B2** (main) | **65.5%** | **66.2%** | **99.73%** |
| U-Net 4-level baseline | 51.9% | 57.7% | 99.66% |
Full per-slice report lives in `reports/metrics.json` after running evaluation.
## Quick Start
```bash
# 1. Install
uv sync --all-groups
# 2. Sync Kaggle dataset into data/raw/ (once)
bash scripts/sync_data.sh /path/to/lgg-mri-segmentation
# 3. Build processed splits
uv run python -m brain_mri_segmentation.data.prepare --raw data/raw --out data/processed
# 4. Train (main model on GPU)
make train
# 5. Evaluate on test split
make evaluate
# 6. Serve the model locally
make serve
# or
docker compose up api
```
## Full Training Commands
**Main — SegFormer-B2:**
```bash
uv run python -m brain_mri_segmentation.training.train experiment=sota
```
**Baseline — U-Net (4 levels, 32→256 ch):**
```bash
uv run python -m brain_mri_segmentation.training.train \
model=baseline \
trainer.max_epochs=30 \
trainer.output_dir=artifacts/baseline
```
Every run is tracked with MLflow under `./mlruns/`; launch `mlflow ui --backend-store-uri ./mlruns` to inspect.
## Inference
```python
from huggingface_hub import snapshot_download
from brain_mri_segmentation.inference.predict import load_model, predict
weights_dir = snapshot_download("kiselyovd/brain-mri-segmentation")
model = load_model(f"{weights_dir}/best.ckpt")
result = predict(model, "slice.tif")
print(f"Mask: {len(result['mask'])}×{len(result['mask'][0])}")
```
`result["mask"]` is a 2-D binary array (H × W) aligned to the input slice.
## Serving
```bash
docker compose up api
curl -X POST -F "file=@slice.tif" http://localhost:8000/segment
```
Endpoints:
| Method | Path | Purpose |
|---|---|---|
| `GET` | `/health` | Liveness probe |
| `POST` | `/segment` | Multipart TIFF/PNG → JSON mask |
| `GET` | `/metrics` | Prometheus metrics |
Every response carries an `X-Request-ID` header for log correlation.
## Project Structure
```
src/brain_mri_segmentation/
├── data/ # MRIDataModule, MRIDataset, prepare.py (patient-level split)
├── models/ # factory.py, lightning_module.py, unet.py
├── training/ # Hydra entrypoint
├── evaluation/ # Dice / IoU / pixel-accuracy report
├── inference/ # load_model + predict
├── serving/ # FastAPI app
└── utils/ # logging, seeding, HF Hub helpers
configs/ # Hydra configs (data / model / trainer / experiment)
data/
├── raw/ # original Kaggle download (images + masks per patient)
└── processed/ # train / val / test splits
docs/ # MkDocs site sources
tests/ # pytest suite
```
## Intended Use
Research and educational only. **Not a medical device.** Predictions must not be used for clinical decisions.
## License
MIT — see [LICENSE](LICENSE).