{"id":22068360,"url":"https://github.com/mazurowski-lab/finetune-SAM","last_synced_at":"2025-07-24T06:30:57.152Z","repository":{"id":233363297,"uuid":"784347564","full_name":"mazurowski-lab/finetune-SAM","owner":"mazurowski-lab","description":"This is an official repo for fine-tuning SAM to customized medical images.","archived":false,"fork":false,"pushed_at":"2024-10-18T17:35:47.000Z","size":1687,"stargazers_count":205,"open_issues_count":8,"forks_count":35,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-06-14T00:31:12.829Z","etag":null,"topics":["finetune","foundation-models","medical-imaging","sam"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2404.09957","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/mazurowski-lab.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-04-09T17:09:33.000Z","updated_at":"2025-06-10T02:47:33.000Z","dependencies_parsed_at":"2024-08-19T19:15:58.229Z","dependency_job_id":"b33eb503-3021-4265-8778-ddf8af964e98","html_url":"https://github.com/mazurowski-lab/finetune-SAM","commit_stats":null,"previous_names":["mazurowski-lab/finetune-sam"],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/mazurowski-lab/finetune-SAM","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mazurowski-lab%2Ffinetune-SAM","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mazurowski-lab%2Ffinetune-SAM/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mazurowski-lab%
2Ffinetune-SAM/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mazurowski-lab%2Ffinetune-SAM/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/mazurowski-lab","download_url":"https://codeload.github.com/mazurowski-lab/finetune-SAM/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mazurowski-lab%2Ffinetune-SAM/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":266802637,"owners_count":23986384,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","status":"online","status_checked_at":"2025-07-24T02:00:09.469Z","response_time":99,"last_error":null,"robots_txt_status":null,"robots_txt_updated_at":null,"robots_txt_url":"https://github.com/robots.txt","online":true,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["finetune","foundation-models","medical-imaging","sam"],"created_at":"2024-11-30T20:03:58.028Z","updated_at":"2025-07-24T06:30:56.132Z","avatar_url":"https://github.com/mazurowski-lab.png","language":"Python","funding_links":[],"categories":["Paper List"],"sub_categories":["Follow-up Papers"],"readme":"# Finetune SAM on your customized medical imaging dataset\nAuthors: [Hanxue Gu*](https://scholar.google.com/citations?hl=en\u0026user=aGjCpQUAAAAJ\u0026view_op=list_works\u0026sortby=pubdate), [Haoyu Dong*](https://scholar.google.com/citations?user=eZVEUCIAAAAJ\u0026hl=en), [Jichen Yang](https://scholar.google.com/citations?user=jGv3bRUAAAAJ\u0026hl=en), [Maciej A. 
Mazurowski](https://scholar.google.com/citations?user=HlxjJPQAAAAJ\u0026hl=en)\n\n**Notice:** 🥰 Hi all, my GitHub account is not linked to my work email, so I may be slow to reply to issues or questions. If you run into problems when using this repo, feel free to email me and I will be glad to help: hanxue.gu@duke.edu.\n\nThis is the official code for our paper: [How to build the best medical image segmentation algorithm using foundation models: a comprehensive empirical study with Segment Anything Model](https://arxiv.org/abs/2404.09957), where we explore three popular scenarios when fine-tuning foundation models to customized datasets in the medical imaging field: (1) only a single labeled dataset; (2) multiple labeled datasets for different tasks; and (3) multiple labeled and unlabeled datasets; and we design three common experimental setups, as shown in Figure 1.\n![Fig1: Overview of general fine-tuning strategies based on different levels of dataset availability.](https://github.com/mazurowski-lab/finetune-SAM/blob/main/finetune_strategy_v9.png)\n\nOur work summarizes and evaluates existing fine-tuning strategies with various backbone architectures, model components, and fine-tuning algorithms across 18 combinations and 17 datasets covering all common radiology modalities.\n![Fig2: Visualization of task-specific fine-tuning architectures selected in our study: 3 encoder architectures $\\times$ 2 model components $\\times$ 3 vanilla/PEFT methods = 18 choices.](https://github.com/mazurowski-lab/finetune-SAM/blob/main/finetune_combination_v3.png)\n\n\nBased on our extensive experiments, we found that:\n1. fine-tuning SAM leads to slightly better performance than previous segmentation methods.\n2. fine-tuning strategies that use parameter-efficient learning in both the encoder and decoder are superior to other strategies.\n3. network architecture has a small impact on the final performance.\n4. 
further training SAM with self-supervised learning can improve final model performance.\n\nTo use our codebase, we provide (a) code to fine-tune SAM on your medical imaging dataset in either the automatic or the prompt-based setting, and (b) pretrained weights obtained from Setup 3 using task-agnostic self-supervised learning, which we found to give better performance on downstream tasks than starting from the initial SAM weights.\n\n## Bug fixes:\n- [X] May-10-2024, fixed a bug introduced when dataset.py was updated on May 6th for multi-class support: the mask resize processing was accidentally omitted.\n- [X] May-10-2024, fixed a bug in the provided single-GPU training demo: it only supported updating the decoder, and the image encoder's gradients were not calculated.\n- [X] June-10-2024, fixed a bug where cfg.py was not updated to the same version as train.sh and was missing the two configs 'train_img_list' and 'val_img_list'.\n\n## Updated functions:\n- [X] May-15-2024, added functions to automatically save the training args and load them for validation, so you do not need to define them manually.\n- [X] May-15-2024, added two Jupyter notebooks showing how to make predictions on 3D volumes/2D PNGs without ground truth, and how to visualize them.\n- [X] May-15-2024, provided two additional example demos.\n- [X] June-10-2024, added a spatial transformation option in dataset.py.\n\n\n## a): fine-tune to one single task-specific dataset \n### Step 0: setup environment\nIf using a conda environment:\n```bash\nconda env create -f environment.yml\n```\nIf using pip directly:\n```bash\npip install -r requirements.txt\n```\n### Step 1: dataset preparation.\nPlease prepare your image and mask pairs as 2D slices first. 
If your original dataset is in 3D format, please preprocess it and save images/masks as 2D slices.\n\nThere is no strict format for your dataset folder; first specify your main dataset folder, for example:\n```\nargs.img_folder = './datasets/'\nargs.mask_folder = './datasets/'\n```\nThen prepare your image/mask list files (train/val/test.csv) under **args.img_folder/dataset_name/**, with one ``img_slice_path mask_slice_path`` pair per line (tab-separated in this example), such as:\n```\nsa_xrayhip/images/image_044.ni_z001.png\tsa_xrayhip/masks/image_044.ni_z001.png\nsa_xrayhip/images/image_126.ni_z001.png\tsa_xrayhip/masks/image_126.ni_z001.png\nsa_xrayhip/images/image_034.ni_z001.png\tsa_xrayhip/masks/image_034.ni_z001.png\nsa_xrayhip/images/image_028.ni_z001.png\tsa_xrayhip/masks/image_028.ni_z001.png\n```\n## Step 2:\nConfigure your network architectures and other hyperparameters.\n### (1) Choose image encoder architecture.\n```\nargs.arch = 'vit_b' # you can pick from 'vit_h','vit_b','vit_t'\n\n# To load the original SAM encoder, for example for 'vit_b':\nargs.sam_ckpt = \"sam_vit_b_01ec64.pth\" \n# You can replace it with any other pretrained weights, such as 'medsam_vit_b.pth'\n```\nDownload SAM's vit-h and vit-b checkpoints from [SAM](https://github.com/facebookresearch/segment-anything). To use MobileSAM, download its checkpoints from [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).\n\n**Note:**\nIf you use MedSAM as the pretrained weights, you need to normalize the dataset to [0, 1] instead of using the original SAM's mean/std normalization.\n```\n# normalize_type: 'sam' or 'medsam'; if 'sam', uses transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]); if 'medsam', uses [0,1] normalization.\n\nargs.normalize_type = 'medsam'\n```\n\n### (2) Choose fine-tuning Methods.\n\n#### (i) Vanilla fine-tuning\n - If you want to update both the encoder and the decoder, just load the network and put:\n```\nargs.if_update_encoder = True\n```\n - If 
you only want to update the mask decoder, just load the network and put:\n```\nargs.if_update_encoder = False\n```\n\n#### (ii) Fine-tuning using Adapter blocks\n- If you want to add adapter blocks to both the image encoder and the mask decoder:\n```\nargs.if_mask_decoder_adapter=True\n\nargs.if_update_encoder = True\nargs.if_encoder_adapter=True\n# You can choose which image encoder blocks get adapters\nargs.encoder_adapter_depths = range(0,12)\n```\n- If you want to add adapter blocks to the decoder only:\n```\nargs.if_mask_decoder_adapter=True\n```\n\n#### (iii) Fine-tuning using LoRA blocks\n- If you want to add LoRA blocks to both the image encoder and the mask decoder:\n```\n# Define which encoder blocks get LoRA layers; if the list is empty ([]), LoRA is added at **each** block.\nargs.if_update_encoder = True\nargs.if_encoder_lora_layer = True\nargs.encoder_lora_layer = []\nargs.if_decoder_lora_layer = True  \n```\n- If you only want to add LoRA blocks to the mask decoder:\n```\nargs.if_decoder_lora_layer = True  \n```\n\n### Other configurations\n1. If you want to enable warmup:\n```\n# If you want to use warmup\nargs.if_warmup = True\nargs.warmup_period = 200\n```\n2. If you want to use DDP training for multiple GPUs, use:\n```\npython DDP_train_xxx.py\n```\nOtherwise, use:\n```\npython SingleGPU_train_xxx.py\n```\nIf the network is too large to fit on a single GPU, you can use our DDP_train_xxx.py and also split the image encoder across 2 GPUs:\n```\nargs.if_split_encoder_gpus = True\nargs.gpu_fractions = [0.5,0.5] # the fraction of image encoder on each GPU\n```\n\n### Multi-cls segmentation VS. binary segmentation\n1. 
if you want to do binary segmentation:\n```\n# set the output channels to 2 (background, object)\nargs.num_cls = 2\n```\n\nIf your target objects actually have multiple labels but you want to combine them as binary:\n```\n# set the dataset's 'targets' parameter to 'combine_all', for example:\nPublic_dataset(args,args.img_folder, args.mask_folder, train_img_list,phase='train',targets=['combine_all'],normalize_type='sam',if_prompt=False)\n```\n2. if you want to do multi-cls segmentation:\n```\n# set the output channels to num_of_target_objects + 1 (background, object1, object2,...)\nargs.num_cls = n+1\n\n# set the dataset's 'targets' parameter to 'multi_all', for example:\nPublic_dataset(args,args.img_folder, args.mask_folder, train_img_list,phase='train',targets=['multi_all'],normalize_type='sam',if_prompt=False)\n```\n\n3. if you actually have multiple different targets but you want to select a subset, such as one target from your mask for training:\n```\nTodo\n```\n\n### Example bash file for running the training\nHere is one example (train_singlegpu_demo.sh) of running the training on a demo dataset using **vit-b** with **Adapter** and updating **Mask Decoder** only.\n```\n#!/bin/bash\n\n# Set CUDA device\nexport CUDA_VISIBLE_DEVICES=\"5\"\n\n# Define variables\narch=\"vit_b\"  # Change this value as needed\nfinetune_type=\"adapter\"\ndataset_name=\"MRI-Prostate\"  # Assuming you set this if it's dynamic\ntargets='combine_all' # 'combine_all' for binary segmentation; 'multi_all' for multi-cls segmentation\n# Construct train and validation image list paths\nimg_folder=\"./datasets\"  # Assuming this is the folder where images are stored\ntrain_img_list=\"${img_folder}/${dataset_name}/train_5shot.csv\"\nval_img_list=\"${img_folder}/${dataset_name}/val_5shot.csv\"\n\n\n# Construct the checkpoint directory argument\ndir_checkpoint=\"2D-SAM_${arch}_decoder_${finetune_type}_${dataset_name}_noprompt\"\n\n# Run the Python script\npython SingleGPU_train_finetune_noprompt.py 
\\\n    -if_warmup True \\\n    -finetune_type \"$finetune_type\" \\\n    -arch \"$arch\" \\\n    -if_mask_decoder_adapter True \\\n    -img_folder \"$img_folder\" \\\n    -mask_folder \"$img_folder\" \\\n    -sam_ckpt \"sam_vit_b_01ec64.pth\" \\\n    -dataset_name \"$dataset_name\" \\\n    -dir_checkpoint \"$dir_checkpoint\" \\\n    -train_img_list \"$train_img_list\" \\\n    -val_img_list \"$val_img_list\"\n```\nTo run the training, just use the command:\n```\nbash train_singlegpu_demo.sh\n# or\nbash train_ddpgpu_demo.sh\n```\n\n### Visualization of the loss\nYou can visualize your training logs using tensorboard; in a terminal, just type:\n```\ntensorboard --logdir args.dir_checkpoint/log --host 0.0.0.0\n```\nThen, open the browser to visualize the loss.\n\n\n### Additional interactive modes\nIf you want to use prompt-based training, set the dataset's prompt type to **prompt_type='point' or prompt_type='box' or prompt_type='hybrid'**, for example:\n```\ntrain_dataset = Public_dataset(args,args.img_folder, args.mask_folder, train_img_list,phase='train',targets=['all'],normalize_type='sam',prompt_type='point')\neval_dataset = Public_dataset(args,args.img_folder, args.mask_folder, val_img_list,phase='val',targets=['all'],normalize_type='sam',prompt_type='point')\n```\nYou also need to edit the block for the prompt encoder input accordingly:\n```\nsparse_emb, dense_emb = sam_fine_tune.prompt_encoder(\n            points=points,\n            boxes=None,\n            masks=None,\n        )\n```\n## Step 3: Validation of the model\n```\nbash val_singlegpu_demo.sh\n```\n\n## Additional model inference mode and prediction visualization\nRefer to 2D_predictions_with_vis.ipynb and 3D_predictions_with_vis.ipynb.\n\n\n## b): fine-tune from task-expansive pretrained weights\nIf you want to use MedSAM as pretrained weights, please refer to [MedSAM](https://github.com/bowang-lab/MedSAM) and download their checkpoint as 'medsam_vit_b.pth'.\n\n## c): fine-tune from task-agnostic 
self-supervised pre-trained weights\nIn our paper, we found that Setup 3, which starts from self-supervised weights and then fine-tunes on one customized dataset using parameter-efficient learning in both the encoder and decoder, provides the best model.\nTo use our self-supervised pretrained weights, please refer to [SSLSAM](https://drive.google.com/drive/folders/1JAoy-Mh5QgxXsjWtQhMjOX16dN1kytLQ).\n\n## To-do list:\n - [x] add the branch of code for automatic multi-cls segmentation\n - [ ] add the branch of code for prompt-based multi-cls segmentation: the output has two channels, and one target is randomly selected at a time during training.\n\n\n## Acknowledgement\nThis work was supported by Duke University.\nWe built this code based on the following:\n1. [SAM](https://github.com/facebookresearch/segment-anything)\n2. [MobileSAM](https://github.com/ChaoningZhang/MobileSAM)\n3. [MedSAM](https://github.com/bowang-lab/MedSAM)\n4. [Medical SAM Adapter](https://github.com/KidsWithTokens/Medical-SAM-Adapter)\n5. [LoRA for SAM](https://github.com/JamesQFreeman/Sam_LoRA)\n\n## Citation\nPlease cite our paper if you find our code or paper helpful; we really appreciate it [🥹 citation, please, cry cry]:\n```bib\n@misc{gu2024build,\n      title={How to build the best medical image segmentation algorithm using foundation models: a comprehensive empirical study with Segment Anything Model}, \n      author={Hanxue Gu and Haoyu Dong and Jichen Yang and Maciej A. Mazurowski},\n      year={2024},\n      eprint={2404.09957},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmazurowski-lab%2Ffinetune-SAM","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmazurowski-lab%2Ffinetune-SAM","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmazurowski-lab%2Ffinetune-SAM/lists"}