https://github.com/dropbox/hqq

Official implementation of Half-Quadratic Quantization (HQQ)

## Half-Quadratic Quantization (HQQ)
This repository contains the official implementation of Half-Quadratic Quantization (HQQ) presented in our articles:
* HQQ: https://dropbox.github.io/hqq_blog/
* HQQ+: https://dropbox.github.io/1bit_blog/

### What is HQQ?
HQQ is a fast and accurate model quantizer that does not require calibration data. It can quantize even the largest models in just a few minutes 🚀.

### FAQ
#### Why should I use HQQ instead of other quantization methods?

- HQQ is very fast to quantize models.
- It supports 8, 4, 3, 2, and 1 bits.
- You can use it on any model (LLMs, vision models, etc.).
- The dequantization step is a linear operation, which means that HQQ is compatible with various optimized CUDA/Triton kernels.
- HQQ is compatible with peft training.
- We try to make HQQ fully compatible with `torch.compile` for faster inference and training.



#### What is the quality of the quantized models?

We have detailed benchmarks on both language and vision models. Please refer to our blog posts: HQQ (https://dropbox.github.io/hqq_blog/) and HQQ+ (https://dropbox.github.io/1bit_blog/).

#### What is the speed of the quantized models?

4-bit models with `axis=1` can use fused kernels optimized for inference. Moreover, we focus on making HQQ fully compatible with `torch.compile`, which speeds up both training and inference. For more details, please refer to the Backends section below.

#### What quantization settings should I use?

You should start with `nbits=4, group_size=64, axis=1`. These settings offer a good balance between quality, VRAM usage, and speed. If you want better results with the same VRAM usage, switch to `axis=0` and use the ATEN backend; note, however, that `axis=0` is not supported by the fast inference backends.
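
As a rough illustration of how `nbits` and `group_size` translate into memory, the sketch below estimates the effective number of bits stored per weight. It assumes one scale and one zero-point per group, each kept in 16 bits. The names `bits_per_weight`, `estimated_size_gb`, and `meta_bits` are hypothetical and only used here; the actual storage layout in hqq may differ (for example, if the metadata is itself compressed).

```python
def bits_per_weight(nbits: int, group_size: int, meta_bits: int = 16) -> float:
    """Estimate the storage cost per weight for group-wise quantization.

    Assumption: every group of `group_size` weights stores one scale and one
    zero-point, each using `meta_bits` bits.
    """
    if nbits <= 0 or group_size <= 0:
        raise ValueError("nbits and group_size must be positive")
    return nbits + 2 * meta_bits / group_size


def estimated_size_gb(num_params: float, nbits: int, group_size: int) -> float:
    """Approximate size in GB (10^9 bytes) of the quantized weights only."""
    return num_params * bits_per_weight(nbits, group_size) / 8 / 1e9


# Recommended starting point: 4-bit weights with groups of 64
print(bits_per_weight(4, 64))
# Smaller groups increase the metadata overhead
print(bits_per_weight(3, 32))
# Hypothetical model with 8e9 quantized parameters
print(estimated_size_gb(8e9, 4, 64))
```

Under these assumptions, halving the group size doubles the per-weight metadata overhead, which is why a smaller `group_size` improves quality at the cost of VRAM.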


#### What does the `axis` parameter mean?

The `axis` parameter is the axis along which grouping is performed. In general, `axis=0` gives better results than `axis=1`, especially at lower bit widths. However, the optimized inference runtime currently supports only `axis=1`.
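
To make the grouping concrete, here is a minimal pure-Python sketch of how a weight matrix could be split into groups for each `axis` value. It assumes that grouping follows a row-major reshape of the flattened weight, to `[-1, group_size]` for `axis=1` and to `[group_size, -1]` for `axis=0`. The helper `make_groups` is hypothetical and is not part of the hqq API; the library's internal layout may differ.

```python
def make_groups(weight, group_size, axis):
    """Split a 2-D weight (list of rows) into groups of `group_size` values.

    Assumption: the weight is flattened in row-major order, then viewed as
    [-1, group_size] for axis=1 (each row of the view is a group) or as
    [group_size, -1] for axis=0 (each column of the view is a group).
    """
    flat = [v for row in weight for v in row]
    if len(flat) % group_size != 0:
        raise ValueError("weight.numel() must be divisible by group_size")
    n_groups = len(flat) // group_size
    if axis == 1:
        return [flat[i * group_size:(i + 1) * group_size] for i in range(n_groups)]
    if axis == 0:
        return [[flat[k * n_groups + j] for k in range(group_size)]
                for j in range(n_groups)]
    raise ValueError("axis must be 0 or 1")


weight = [[0, 1, 2, 3],
          [4, 5, 6, 7]]

# axis=1: groups are consecutive values within a row
groups_axis1 = make_groups(weight, group_size=2, axis=1)
# axis=0: with group_size equal to the number of rows, groups are the columns
groups_axis0 = make_groups(weight, group_size=2, axis=0)
print(groups_axis1)
print(groups_axis0)
```

Each group gets its own quantization parameters, so the choice of axis decides which weights share a scale and zero-point.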


#### What is the difference between HQQ and HQQ+?

HQQ+ is HQQ with trainable low-rank adapters to improve the quantization quality at lower bits.

### Installation
First, make sure you have a PyTorch 2 version that matches your CUDA version: https://pytorch.org/

You can install hqq via:
```
#Latest stable version
pip install hqq;

#Latest updates - recommended
pip install git+https://github.com/dropbox/hqq.git;

#Disable building the CUDA kernels for the aten backend
DISABLE_CUDA=1 pip install ...
```

Alternatively, clone the repo and run ```pip install .``` from the repository's root folder.

### Basic Usage
To perform quantization with HQQ, you simply need to replace the linear layers (```torch.nn.Linear```) as follows:
```Python
import torch
from hqq.core.quantize import *

#Quantization settings
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)

#Replace your linear layer
hqq_layer = HQQLinear(
    your_linear_layer,           #torch.nn.Linear or None
    quant_config=quant_config,   #quantization configuration
    compute_dtype=torch.float16, #compute dtype
    device='cuda',               #cuda device
    initialize=True,             #Use False to quantize later
    del_orig=True,               #if True, delete the original layer
)

W_r = hqq_layer.dequantize()              #dequantized weight
W_q = hqq_layer.unpack(dtype=torch.uint8) #unpacked quantized weight
y   = hqq_layer(x)                        #forward pass
```

The quantization parameters are set as follows:

- ```nbits``` (int): supports 8, 4, 3, 2, 1 bits.
- ```group_size``` (int): no restrictions as long as ```weight.numel()``` is divisible by the ```group_size```.
- ```view_as_float``` (bool): if True, the quantized parameter is viewed as float instead of an int type.
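
To illustrate what `nbits` and `group_size` control, the following sketch performs plain round-to-nearest, group-wise affine quantization in pure Python and then dequantizes with the linear map `W_r = (W_q - zero) * scale`. This is only a simplified stand-in: it uses a min/max fit per group and does not implement HQQ's half-quadratic optimization. The function names are hypothetical and not part of the hqq API.

```python
def quantize_group(values, nbits):
    """Round-to-nearest affine quantization of a single group.

    Note: plain min/max quantization, not HQQ's optimizer.
    """
    q_max = 2 ** nbits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / q_max if hi > lo else 1.0
    zero = -lo / scale
    q = [min(q_max, max(0, round(v / scale + zero))) for v in values]
    return q, scale, zero


def dequantize_group(q, scale, zero):
    """Linear dequantization: W_r = (W_q - zero) * scale."""
    return [(v - zero) * scale for v in q]


def quantize(flat_weights, nbits, group_size):
    """Quantize a flat list of weights group by group."""
    if len(flat_weights) % group_size != 0:
        raise ValueError("number of weights must be divisible by group_size")
    return [
        quantize_group(flat_weights[i:i + group_size], nbits)
        for i in range(0, len(flat_weights), group_size)
    ]


weights = [-1.0, -0.5, 0.0, 1.0, 0.1, 0.2, 0.3, 0.4]
packed = quantize(weights, nbits=4, group_size=4)
restored = [w for q, scale, zero in packed
            for w in dequantize_group(q, scale, zero)]
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(max_error)
```

In this sketch the reconstruction error of each weight is bounded by half of its group's `scale`, so more bits or smaller groups both reduce the error.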

### Usage with Models
#### Transformers 🤗
For usage with HF's transformers, see the example below from the documentation:
```Python
import torch
from transformers import AutoModelForCausalLM, HqqConfig

# All linear layers will use the same quantization config
quant_config = HqqConfig(nbits=4, group_size=64)

# Load and quantize (model_id is a Hugging Face model name or a local path)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)
```
You can save/load quantized models as regular transformers models via `save_pretrained` / `from_pretrained`.

#### HQQ Lib
You can also use the HQQ library to quantize transformers models:
```Python
import torch
from transformers import AutoModelForCausalLM
from hqq.core.quantize import BaseQuantizeConfig
from hqq.models.hf.base import AutoHQQHFModel

#Example values, adjust to your setup
compute_dtype = torch.float16
device = 'cuda'

#Load the model on CPU
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype)

#Quantize
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
```
You can save/load quantized models as follows:
```Python
from hqq.models.hf.base import AutoHQQHFModel

#Save: Make sure to save the model BEFORE any patching
AutoHQQHFModel.save_quantized(model, save_dir)

#Save as safetensors (to be loaded via transformers or vLLM)
AutoHQQHFModel.save_to_safetensors(model, save_dir)

#Load
model = AutoHQQHFModel.from_quantized(save_dir)
```

❗ Note that models saved via the hqq lib are not compatible with `.from_pretrained()`

### Backends
#### Native Backends
The following native dequantization backends can be used by the `HQQLinear` module:
```Python
from hqq.core.quantize import HQQLinear, HQQBackend

HQQLinear.set_backend(HQQBackend.PYTORCH)         #PyTorch backend - Default
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE) #Compiled PyTorch
HQQLinear.set_backend(HQQBackend.ATEN)            #ATen/CUDA backend - only axis=0 supported
```
❗ Note that ```HQQBackend.ATEN``` only supports `axis=0`.

#### Optimized Inference
We support external backends for faster inference with fused kernels. You can enable one of the backends after the model has been quantized as follows:
```Python
from hqq.utils.patching import prepare_for_inference

#Pytorch backend that makes the model compatible with fullgraph torch.compile: works with any settings
#prepare_for_inference(model)

#Gemlite backend: nbits=4/2/1, compute_dtype=float16, axis=1
prepare_for_inference(model, backend="gemlite")

#Torchao's tiny_gemm backend (fast for batch-size<4): nbits=4, compute_dtype=bfloat16, axis=1
#prepare_for_inference(model, backend="torchao_int4")
```
Note that these backends only work with `axis=1`. Additional restrictions on the group-size values apply depending on the backend. You should expect ~158 tokens/sec with a Llama3-8B 4-bit quantized model on an RTX 4090.

When a quantization config is not supported by the specified inference backend, hqq will fall back to the native backend.
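
The fallback behavior can be pictured with the small sketch below. It encodes only the constraints listed in the comments of the snippet above (bit widths, compute dtype, and axis) and ignores the backend-specific group-size restrictions. `BACKEND_CONSTRAINTS` and `select_backend` are hypothetical names for illustration; hqq performs this check internally and its logic may differ.

```python
# Constraints taken from the comments in the prepare_for_inference snippet.
# Group-size restrictions are not modeled in this sketch.
BACKEND_CONSTRAINTS = {
    "gemlite": {"nbits": {4, 2, 1}, "compute_dtype": "float16", "axis": 1},
    "torchao_int4": {"nbits": {4}, "compute_dtype": "bfloat16", "axis": 1},
}


def select_backend(requested, nbits, compute_dtype, axis):
    """Return `requested` if the settings satisfy its constraints, else "native"."""
    rules = BACKEND_CONSTRAINTS.get(requested)
    if rules is None:
        return "native"
    supported = (
        nbits in rules["nbits"]
        and compute_dtype == rules["compute_dtype"]
        and axis == rules["axis"]
    )
    return requested if supported else "native"


# A supported configuration keeps the requested backend
print(select_backend("gemlite", nbits=4, compute_dtype="float16", axis=1))
# 3-bit weights are not listed for gemlite, so the sketch falls back
print(select_backend("gemlite", nbits=3, compute_dtype="float16", axis=1))
```

A check like this can be useful before quantizing, since choosing `axis=0` or an unsupported bit width silently forgoes the fused kernels.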

### Custom Quantization Configurations ⚙️
You can set up various quantization configurations for different layers by specifying the settings for each layer name:
#### Transformers 🤗
```Python
from transformers import HqqConfig

# Each linear layer with the same tag will use a dedicated quantization config
q4_config = {'nbits':4, 'group_size':64}
q3_config = {'nbits':3, 'group_size':32}

quant_config = HqqConfig(dynamic_config={
    'self_attn.q_proj':q4_config,
    'self_attn.k_proj':q4_config,
    'self_attn.v_proj':q4_config,
    'self_attn.o_proj':q4_config,

    'mlp.gate_proj':q3_config,
    'mlp.up_proj'  :q3_config,
    'mlp.down_proj':q3_config,
})
```
#### HQQ Lib
```Python
from hqq.core.quantize import *
q4_config = BaseQuantizeConfig(nbits=4, group_size=64)
q3_config = BaseQuantizeConfig(nbits=3, group_size=32)

quant_config = {
    'self_attn.q_proj':q4_config,
    'self_attn.k_proj':q4_config,
    'self_attn.v_proj':q4_config,
    'self_attn.o_proj':q4_config,

    'mlp.gate_proj':q3_config,
    'mlp.up_proj'  :q3_config,
    'mlp.down_proj':q3_config,
}
```
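
The sketch below shows one way such a tag-to-config mapping could be resolved for a given module. It assumes that a tag such as `self_attn.q_proj` matches any module whose full dotted name ends with that tag. `resolve_layer_config` and the module names are hypothetical, and the matching logic used by hqq or transformers may differ.

```python
def resolve_layer_config(module_name, quant_config):
    """Return the config whose tag is a dotted suffix of `module_name`, or None."""
    for tag, config in quant_config.items():
        if module_name == tag or module_name.endswith("." + tag):
            return config
    return None


q4_config = {'nbits': 4, 'group_size': 64}
q3_config = {'nbits': 3, 'group_size': 32}
layer_configs = {
    'self_attn.q_proj': q4_config,
    'mlp.down_proj': q3_config,
}

# Hypothetical full module names, as they might appear in a decoder model
print(resolve_layer_config('model.layers.0.self_attn.q_proj', layer_configs))
print(resolve_layer_config('model.layers.5.mlp.down_proj', layer_configs))
# No matching tag: in this sketch the layer would be left unquantized
print(resolve_layer_config('lm_head', layer_configs))
```

Because the same tag repeats in every transformer block, one entry per projection type is enough to cover all layers of the model.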

### vLLM
You can use HQQ in vLLM. Make sure to install GemLite before using the backend.

```Python
import torch
from vllm import LLM

#Quantize on-the-fly
from hqq.utils.vllm import set_vllm_onthefly_hqq_quant
skip_modules = ['lm_head', 'visual', 'vision']

#Select one of the following modes:

#INT/FP format
set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=None, quant_mode='int8_weightonly', skip_modules=skip_modules) #A16W8 - INT8 weight only
set_vllm_onthefly_hqq_quant(weight_bits=4, group_size=128, quant_mode='int4_weightonly', skip_modules=skip_modules) #A16W4 - HQQ weight only
set_vllm_onthefly_hqq_quant(weight_bits=8, quant_mode='int8_dynamic', skip_modules=skip_modules) #A8W8 - INT8 x INT8 dynamic
set_vllm_onthefly_hqq_quant(weight_bits=8, quant_mode='fp8_dynamic', skip_modules=skip_modules) #A8W8 - FP8 x FP8 dynamic

#MXFP format
set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=None, quant_mode='mxfp8_dynamic', skip_modules=skip_modules) #A8W8 - MXFP8 x MXFP8 - post_scale=True
set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=32, quant_mode='mxfp8_dynamic', skip_modules=skip_modules) #A8W8 - MXFP8 x MXFP8 - post_scale=False
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp4_weightonly', skip_modules=skip_modules) #A16W4 - MXFP4 weight-only
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp8_dynamic', skip_modules=skip_modules) #A8W4 - MXFP8 x MXFP4 dynamic
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp4_dynamic', skip_modules=skip_modules) #A4W4 - MXFP4 x MXFP4 dynamic
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='nvfp4_dynamic', skip_modules=skip_modules) #A4W4 - NVFP4 x NVFP4 dynamic

llm = LLM(model="meta-llama/Llama-3.2-3B-Instruct", max_model_len=4096, gpu_memory_utilization=0.80, dtype=torch.float16)
```

### Peft Training
PEFT training is directly supported by Hugging Face's peft library. If you still want to use hqq-lib's peft utilities, here's how:

```Python
#First, quantize a model or load a quantized HQQ model
import torch
from hqq.core.quantize import HQQLinear, HQQBackend
from hqq.core.peft import PeftUtils

base_lora_params = {'lora_type':'default', 'r':32, 'lora_alpha':64, 'dropout':0.05, 'train_dtype':torch.float32}
lora_params = {'self_attn.q_proj': base_lora_params,
               'self_attn.k_proj': base_lora_params,
               'self_attn.v_proj': base_lora_params,
               'self_attn.o_proj': base_lora_params,
               'mlp.gate_proj'   : None,
               'mlp.up_proj'     : None,
               'mlp.down_proj'   : None}

#Add LoRA to linear/HQQ modules
PeftUtils.add_lora(model, lora_params)

#Optional: set your backend
HQQLinear.set_backend(HQQBackend.ATEN if axis==0 else HQQBackend.PYTORCH_COMPILE)

#Train ....

#Convert LoRA weights to the same model dtype for faster inference
model.eval()
PeftUtils.cast_lora_weights(model, dtype=compute_dtype)

#Save LoRA weights
PeftUtils.save_lora_weights(model, filename)

#Load LoRA weights: automatically calls add_lora
PeftUtils.load_lora_weights(model, filename)
```

We provide a complete example to train a model with HQQ/LoRA that you can find in ```examples/hqq_plus.py```.
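
To make the `r` and `lora_alpha` settings concrete, here is a minimal pure-Python sketch of the standard LoRA forward pass on top of a frozen (dequantized) weight: `y = W_r x + (lora_alpha / r) * B (A x)`. It assumes that `'lora_type':'default'` corresponds to this standard formulation and it omits dropout. The helper names are hypothetical and not part of `PeftUtils`.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_linear(x, W_r, A, B, r, lora_alpha):
    """Standard LoRA forward pass (dropout omitted).

    W_r: frozen dequantized weight, shape [out, in]
    A:   trainable down-projection, shape [r, in]
    B:   trainable up-projection, shape [out, r]
    """
    scaling = lora_alpha / r
    base = matvec(W_r, x)
    update = matvec(B, matvec(A, x))
    return [b + scaling * u for b, u in zip(base, update)]


W_r = [[1.0, 0.0],
       [0.0, 1.0]]
A = [[1.0, 1.0]]        # rank r = 1
B = [[0.5],
     [0.0]]
x = [2.0, 3.0]

y = lora_linear(x, W_r, A, B, r=1, lora_alpha=2)
print(y)
```

With the values used in the configuration above (`r=32`, `lora_alpha=64`), the scaling factor `lora_alpha / r` would be 2. When `B` is all zeros, as is common at initialization, the output equals that of the quantized layer alone.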

If you want to use multi-GPU training via FSDP, check out this awesome repo by Answer.AI: https://github.com/AnswerDotAI/fsdp_qlora

### Examples
We provide a variety of examples demonstrating model quantization across different backends within the ```examples``` directory.

### Citation 📜
```
@misc{badri2023hqq,
    title  = {Half-Quadratic Quantization of Large Machine Learning Models},
    url    = {https://dropbox.github.io/hqq_blog/},
    author = {Hicham Badri and Appu Shaji},
    month  = {November},
    year   = {2023}
}
```