Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/fkodom/lora-pytorch
Simple but robust implementation of LoRA for PyTorch. Compatible with NLP, CV, and other model types. Strongly typed and tested.
- Host: GitHub
- URL: https://github.com/fkodom/lora-pytorch
- Owner: fkodom
- License: MIT
- Created: 2023-07-07T13:36:14.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-05-13T01:16:32.000Z (7 months ago)
- Last Synced: 2024-10-23T03:17:27.669Z (2 months ago)
- Language: Python
- Size: 56.6 KB
- Stars: 21
- Watchers: 3
- Forks: 3
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# lora-pytorch
A simple but robust implementation of [LoRA (Low-Rank Adaptation)](https://arxiv.org/pdf/2106.09685.pdf) for PyTorch, which depends only on PyTorch itself! No dependence on `transformers` or other packages.
* Compatible with LLMs, CNNs, MLPs, and other model types ✔️
* Strongly typed ✔️
* Fully tested ✔️

## Install

PyPI:
```bash
pip install lora-pytorch
```

From source:
```bash
pip install "lora-pytorch @ git+ssh://[email protected]/fkodom/lora-pytorch.git"
```

For contributors:
```bash
# Clone repository
gh repo clone fkodom/lora-pytorch
# Install all dev dependencies (tests etc.)
cd lora-pytorch
pip install -e ".[all]"
# Setup pre-commit hooks
pre-commit install
```

## Usage

```python
import torch
from lora_pytorch import LoRA
from torchvision.models import resnet18, ResNet

# Wrap your model with LoRA
model = resnet18()
lora_model = LoRA.from_module(model, rank=5)

print(lora_model)
# LoRA(
#   (module): ResNet(
#     (conv1): LoRA(
#       (module): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#       (lora_module): Conv2dLoRAModule(
#         (in_conv): Conv2d(3, 5, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         (out_conv): Conv2d(5, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
#         (dropout): Dropout(p=0.0, inplace=False)
#       )
#     )
#     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (relu): ReLU(inplace=True)
#     ...

# Train or predict as usual.
x = torch.randn(1, 3, 224, 224)
y = lora_model(x)
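
# A minimal, illustrative training step (the optimizer and loss below are
# placeholders, not part of this library's API). Only the LoRA parameters
# require gradients, so only the adapters are updated:
optimizer = torch.optim.AdamW(
    [p for p in lora_model.parameters() if p.requires_grad], lr=1e-4
)
loss = y.sum()  # stand-in for a real loss function
loss.backward()
optimizer.step()
optimizer.zero_grad()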

# Merge LoRA weights into the original model.
new_model = lora_model.merge_lora(inplace=False)  # default: inplace=False

# NOTE: new_model has the same type as the original model! Inference is just as
# fast as in the original model.
assert isinstance(new_model, ResNet)
```

### Advanced Usage

Enable or disable `LoRA` as needed (e.g., to access the original model).
**NOTE**: `LoRA` will *not* track gradients from the original model.
```python
# Disable
lora_model.disable_lora()
y = lora_model(x)
print(y.requires_grad)
# False

# Re-enable
lora_model.enable_lora()
y = lora_model(x)
print(y.requires_grad)
# True
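
# Sanity check (not part of the library API): with LoRA enabled, only the
# low-rank adapter parameters should be trainable.
num_trainable = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
num_total = sum(p.numel() for p in lora_model.parameters())
print(f"Trainable parameters: {num_trainable} / {num_total}")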
```

Remove `LoRA` from the model.
**NOTE**: The original model weights will be unchanged.
```python
# Remove
original_model = lora_model.remove_lora(inplace=False) # default: inplace=False
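# Unlike `merge_lora`, the low-rank updates are discarded here, so the
# recovered model keeps its original (pre-LoRA) weights.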
assert isinstance(original_model, ResNet)
```

## Supported Layers

Layer | Supported
--- | ---
`nn.Linear` | ✅
`nn.Embedding` | ✅
`nn.MultiheadAttention` | ✅
`nn.TransformerEncoder` | ✅
`nn.TransformerEncoderLayer` | ✅
`nn.TransformerDecoder` | ✅
`nn.TransformerDecoderLayer` | ✅
`nn.Transformer` | ✅
`nn.Conv1d` | ✅
`nn.Conv2d` | ✅
`nn.Conv3d` | ✅
`nn.ConvTranspose1d` | ❌
`nn.ConvTranspose2d` | ❌
`nn.ConvTranspose3d` | ❌

**NOTE**: Activation, normalization, dropout, etc. layers are not affected by `LoRA`. Those are not listed here, but you shouldn't have any problems using them.
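
If you want to see exactly which submodules were wrapped, one option is to walk the module tree and look for nested `LoRA` instances. This is a minimal sketch based on the module structure shown in the printed model above, not an official helper:

```python
from lora_pytorch import LoRA
from torchvision.models import resnet18

lora_model = LoRA.from_module(resnet18(), rank=5)

# Nested LoRA instances mark the layers that received low-rank adapters;
# unaffected layers (BatchNorm, ReLU, etc.) are left unwrapped.
for name, module in lora_model.named_modules():
    if isinstance(module, LoRA) and name:  # skip the top-level wrapper
        print(name)
```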
## TODO
* Add support for `ConvTranspose` layers.
* Experiments with large, pretrained models
* Specifically, models that are not covered by LoRA in [huggingface/transformers](https://github.com/huggingface/transformers).
* Lots of CV examples: ResNet, ViT, DETR, UNET, DeepLab, etc.