
# Fine-tuning Pixtral

This repository contains a script for training the Transformers-compatible [Pixtral-12b](https://huggingface.co/mistral-community/pixtral-12b).

However, the model currently supports only **batch size = 1**, so fine-tuning can take a long time.

## Other projects

**[[Phi3-Vision Finetuning]](https://github.com/2U1/Phi3-Vision-Finetune)**

**[[Llama3.2-Vision Finetuning]](https://github.com/2U1/Llama3.2-Vision-Ft)**

**[[Qwen2-VL Finetuning]](https://github.com/2U1/Qwen2-VL-Finetune)**

**[[Molmo Finetune]](https://github.com/2U1/Molmo-Finetune)**

**[[SmolVLM Finetune]](https://github.com/2U1/SmolVLM-Finetune)**

## Update

- [2025/01/24] Add option for using DoRA.
- [2025/01/24] Fix error in LoRA training.
- [2025/01/11] Updated 8-bit training using ms_amp fp8 with opt_level O3.

## Table of Contents

- [Fine-tuning Pixtral](#fine-tuning-pixtral)
  - [Other projects](#other-projects)
  - [Update](#update)
  - [Table of Contents](#table-of-contents)
  - [Supported Features](#supported-features)
  - [Installation](#installation)
    - [Using `environment.yaml`](#using-environmentyaml)
  - [Dataset Preparation](#dataset-preparation)
  - [Training](#training)
    - [Full Finetuning](#full-finetuning)
    - [Full Finetuning with 8-bit](#full-finetuning-with-8-bit)
    - [Finetune with LoRA](#finetune-with-lora)
    - [Train with video dataset](#train-with-video-dataset)
      - [Merge LoRA Weights](#merge-lora-weights)
      - [Issue for libcudnn error](#issue-for-libcudnn-error)
  - [TODO](#todo)
  - [Known Issues](#known-issues)
  - [License](#license)
  - [Citation](#citation)
  - [Acknowledgement](#acknowledgement)

## Supported Features

- Deepspeed
- LoRA/QLoRA
- Full-finetuning
- Enable finetuning `vision_model` while using LoRA.
- Disable/enable Flash Attention 2
- Multi-image and video training
- Training optimized with liger kernel

## Installation

Install the required packages using `environment.yaml`.

### Using `environment.yaml`

```bash
conda env create -f environment.yaml
conda activate pixtral
pip install flash-attn==2.5.8 --no-build-isolation
```

**Note:** You should install flash-attn after installing the other packages.
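A quick way to confirm the environment is ready before training (a hypothetical sanity check, not part of the repo's scripts):

```python
# Hypothetical sanity check for the freshly created environment.
import torch

print(torch.__version__, "CUDA available:", torch.cuda.is_available())

try:
    import flash_attn  # installed last, per the note above
    print("flash-attn", flash_attn.__version__)
except ImportError:
    print("flash-attn missing; run: pip install flash-attn==2.5.8 --no-build-isolation")
```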

## Dataset Preparation

The script requires a dataset formatted according to the LLaVA specification. The dataset should be a JSON file where each entry contains information about conversations and images. Ensure that the image paths in the dataset match the provided `--image_folder`.

**When using a multi-image dataset, the image tokens should all be `<image>`, and the image file names should be in a list.**
**Please see the examples below and format your data accordingly.**

Example for a single-image dataset

```json
[
  {
    "id": "000000033471",
    "image": "000000033471.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "<image>\nWhat are the colors of the bus in the image?"
      },
      {
        "from": "gpt",
        "value": "The bus in the image is white and red."
      },
      {
        "from": "human",
        "value": "What feature can be seen on the back of the bus?"
      },
      {
        "from": "gpt",
        "value": "The back of the bus features an advertisement."
      },
      {
        "from": "human",
        "value": "Is the bus driving down the street or pulled off to the side?"
      },
      {
        "from": "gpt",
        "value": "The bus is driving down the street, which is crowded with people and other vehicles."
      }
    ]
  }
  ...
]
```

Example for a multi-image dataset

```json
[
  {
    "id": "000000033471",
    "image": ["000000033471.jpg", "000000033472.jpg"],
    "conversations": [
      {
        "from": "human",
        "value": "<image>\n<image>\nIs the perspective of the camera different?"
      },
      {
        "from": "gpt",
        "value": "Yes, the perspective of the camera is different."
      }
    ]
  }
  ...
]
```

Example for a video dataset

```json
[
  {
    "id": "sample1",
    "video": "sample1.mp4",
    "conversations": [
      {
        "from": "human",
        "value": "<video>\nWhat is going on in this video?"
      },
      {
        "from": "gpt",
        "value": "A man is walking down the road."
      }
    ]
  }
  ...
]
```

**Note:** Pixtral doesn't officially support video, but it does support multi-image input, so you can treat a video as a sequence of frames.
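If you want to sanity-check a dataset before training, a minimal sketch like the following can catch missing files and mismatched image tokens (the function and paths here are illustrative, not part of the repo):

```python
# Minimal sketch: validate a LLaVA-format JSON file against an image folder.
# Assumes the single-/multi-image schema shown above; names are illustrative.
import json
from pathlib import Path

def validate(data_path: str, image_folder: str) -> None:
    entries = json.loads(Path(data_path).read_text())
    root = Path(image_folder)
    for entry in entries:
        images = entry.get("image", [])
        if isinstance(images, str):  # single-image entries store a plain string
            images = [images]
        for name in images:
            assert (root / name).exists(), f"missing image: {name}"
        # The first human turn should carry one <image> token per image file.
        tokens = entry["conversations"][0]["value"].count("<image>")
        assert tokens == len(images), f"{entry['id']}: {tokens} tokens for {len(images)} images"

validate("train.json", "images/")
```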

## Training

To run the training script, use the following command:

### Full Finetuning

```bash
bash scripts/finetune.sh
```

### Full Finetuning with 8-bit

```bash
bash scripts/finetune_8bit.sh
```

**You need to install [ms-amp](https://github.com/Azure/MS-AMP) to use this script.**

This script will finetune the model with the fp8 model dtype. You can use this if you run out of VRAM.

You can also use fp8 training together with offloading.

### Finetune with LoRA

If you want to train only the language model with LoRA and perform full training for the vision model:

```bash
bash scripts/finetune_lora.sh
```

If you want to train both the language model and the vision model with LoRA:

```bash
bash scripts/finetune_lora_vision.sh
```

**IMPORTANT:** If you want to tune the `embed_token` with LoRA, you need to tune `lm_head` together with it.
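In PEFT terms, that note roughly corresponds to saving both modules alongside the adapters. A sketch of the idea, using the documented defaults (rank 128, alpha 256, dropout 0.05); the target module names are assumptions, not the exact config used by `finetune_lora.sh`:

```python
# Illustrative LoRA config only; module names are assumptions.
from peft import LoraConfig

lora_config = LoraConfig(
    r=128,
    lora_alpha=256,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections
    modules_to_save=["embed_tokens", "lm_head"],  # tune embeddings and lm_head together
)
```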

Training arguments

- `--deepspeed` (str): Path to DeepSpeed config file (default: "scripts/zero2.json").
- `--data_path` (str): Path to the LLaVA formatted training data (a JSON file). **(Required)**
- `--image_folder` (str): Path to the images folder as referenced in the LLaVA formatted training data. **(Required)**
- `--model_id` (str): Path to the Pixtral model. **(Required)**
- `--output_dir` (str): Output directory for model checkpoints
- `--num_train_epochs` (int): Number of training epochs (default: 1).
- `--per_device_train_batch_size` (int): Training batch size per GPU per forward step.
- `--gradient_accumulation_steps` (int): Gradient accumulation steps (default: 4).
- `--freeze_vision_tower` (bool): Option to freeze vision_model (default: False).
- `--freeze_llm` (bool): Option to freeze LLM (default: False).
- `--tune_merger` (bool): Option to tune projector (default: True).
- `--num_lora_modules` (int): Number of target modules to add LoRA (-1 means all layers).
- `--vision_lr` (float): Learning rate for vision_model.
- `--merger_lr` (float): Learning rate for merger (projector).
- `--learning_rate` (float): Learning rate for language module.
- `--max_num_frames` (int): Maximum number of frames for video dataset (default: 10).
- `--bf16` (bool): Option for using bfloat16.
- `--fp16` (bool): Option for using fp16.
- `--min_pixels` (int): Option for minimum input tokens.
- `--max_pixels` (int): Option for maximum input tokens.
- `--lora_enable` (bool): Option for enabling LoRA (default: False)
- `--vision_lora` (bool): Option for including vision_tower to the LoRA module. The `lora_enable` should be `True` to use this option. (default: False)
- `--use_dora` (bool): Option for using DoRA instead of LoRA. The `lora_enable` should be `True` to use this option. (default: False)
- `--lora_namespan_exclude` (str): Exclude modules with namespans to add LoRA.
- `--max_seq_length` (int): Maximum sequence length (default: 32K).
- `--bits` (int): Quantization bits (default: 16).
- `--disable_flash_attn2` (bool): Disable Flash Attention 2.
- `--report_to` (str): Reporting tool (choices: 'tensorboard', 'wandb', 'none') (default: 'tensorboard').
- `--logging_dir` (str): Logging directory (default: "./tf-logs").
- `--lora_rank` (int): LoRA rank (default: 128).
- `--lora_alpha` (int): LoRA alpha (default: 256).
- `--lora_dropout` (float): LoRA dropout (default: 0.05).
- `--logging_steps` (int): Logging steps (default: 1).
- `--dataloader_num_workers` (int): Number of data loader workers (default: 4).

**Note:** The learning rate of `vision_model` should be 5x to 10x smaller than that of the `language_model`.
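The split learning rates boil down to optimizer parameter groups; a toy sketch of the idea (the real grouping lives inside the training script):

```python
# Toy modules standing in for the real submodels, just to show per-module LRs.
import torch

model = torch.nn.ModuleDict({
    "vision_model": torch.nn.Linear(8, 8),
    "language_model": torch.nn.Linear(8, 8),
})
optimizer = torch.optim.AdamW([
    {"params": model["vision_model"].parameters(), "lr": 2e-6},   # 10x smaller
    {"params": model["language_model"].parameters(), "lr": 2e-5},
])
```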

### Train with video dataset

You can train the model using a video dataset. However, since Pixtral doesn't officially support video, this code processes videos as a sequence of images; you'll need to select specific frames and treat them as multiple images for training. You can also set LoRA configs and train with LoRA here as well.

```bash
bash scripts/finetune_video.sh
```

**Note**: You should adjust `max_num_frames` based on the available VRAM.
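To illustrate what `max_num_frames` controls, here is a rough frame-sampling sketch using OpenCV (the training script has its own sampler; this only shows the general idea):

```python
# Rough sketch: uniformly sample up to max_num_frames frames from a video.
import cv2

def sample_frames(video_path: str, max_num_frames: int = 10) -> list:
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total // max_num_frames, 1)
    frames = []
    for idx in range(0, total, step):
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(frame)
        if len(frames) == max_num_frames:
            break
    cap.release()
    return frames  # treated downstream as a multi-image sample
```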

If you run out of VRAM, you can use [zero3_offload](./scripts/zero3_offload.json) instead of [zero3](./scripts/zero3.json). However, using zero3 is preferred.

#### Merge LoRA Weights

```bash
bash scripts/merge_lora.sh
```

**Note:** Remember to replace the paths in `finetune.sh` or `finetune_lora.sh` with your specific paths. (Also in `merge_lora.sh` when using LoRA.)
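Under the hood, merging is the standard PEFT merge-and-unload pattern. A minimal sketch with placeholder paths (the model class and paths are assumptions; consult `merge_lora.sh` for the repo's actual logic):

```python
# Minimal merge sketch; paths and the model class are assumptions.
from peft import PeftModel
from transformers import LlavaForConditionalGeneration

base = LlavaForConditionalGeneration.from_pretrained("mistral-community/pixtral-12b")
model = PeftModel.from_pretrained(base, "output/lora_checkpoint")
merged = model.merge_and_unload()  # folds LoRA deltas into the base weights
merged.save_pretrained("output/merged_model")
```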

#### Issue for libcudnn error

```
Could not load library libcudnn_cnn_train.so.8. Error: /usr/local/cuda-12.1/lib/libcudnn_cnn_train.so.8: undefined symbol: _ZN5cudnn3cnn34layerNormFwd_execute_internal_implERKNS_7backend11VariantPackEP11CUstream_stRNS0_18LayerNormFwdParamsERKNS1_20NormForwardOperationEmb, version libcudnn_cnn_infer.so.8
```

You can run `unset LD_LIBRARY_PATH` to work around this error.
See this [issue](https://github.com/andimarafioti/florence2-finetuning/issues/2) for more details.

## TODO

- [ ] Support batch size > 1

## Known Issues

- [libcudnn issue](#issue-for-libcudnn-error)

## License

This project is licensed under the Apache-2.0 License. See the [LICENSE](LICENSE) file for details.

## Citation

If you find this repository useful in your project, please consider giving a :star: and citing:

```bibtex
@misc{Pixtral-Finetuning,
author = {Yuwon Lee},
title = {Pixtral-Finetune},
year = {2024},
publisher = {GitHub},
url = {https://github.com/2U1/Pixtral-Finetune}
}
```

## Acknowledgement

This project is based on

- [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT): An amazing open-source LMM project.
- [Pixtral-12B](https://huggingface.co/mistral-community/pixtral-12b): Transformers-compatible version of Pixtral-12B.