https://github.com/lucidrains/calm-pytorch
Implementation of CALM from the paper "LLM Augmented LLMs: Expanding Capabilities through Composition", from Google DeepMind
- Host: GitHub
- URL: https://github.com/lucidrains/calm-pytorch
- Owner: lucidrains
- License: mit
- Created: 2024-01-09T15:49:20.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-09-12T17:26:55.000Z (10 months ago)
- Last Synced: 2025-03-30T00:06:16.020Z (3 months ago)
- Topics: artificial-intelligence, attention-mechanisms, cross-attention, deep-learning, transformers
- Language: Python
- Homepage:
- Size: 939 KB
- Stars: 174
- Watchers: 8
- Forks: 11
- Open Issues: 3
Metadata Files:
- Readme: README.md
- License: LICENSE
## CALM - Pytorch
Implementation of CALM from the paper "LLM Augmented LLMs: Expanding Capabilities through Composition", from Google DeepMind
Can support any number of augmentation LLMs
## Install
```bash
$ pip install CALM-pytorch
```

## Appreciation
- A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
## Usage
ex. with `x-transformers`
```python
import torch
from x_transformers import TransformerWrapper, Decoder

augment_llm = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 12,
heads = 8
)
)

anchor_llm = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 2,
heads = 8
)
)

# import CALM wrapper

from CALM_pytorch import CALM, AugmentParams

calm = CALM(
anchor_llm,
augment_llms = AugmentParams(
model = augment_llm,
connect_every_num_layers = 4
)
)

# mock input

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()
prompt = torch.randint(0, 20000, (1, 256))

# forward for finetuning loss

loss = calm(
seq,
mask = mask,
prompt = prompt
)

loss.backward()

# after much training, prompt the composed model
generated = calm.generate(
prompt = seq[:, :1],
seq_len = 1024
)
```
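Before reaching for the trainer below: the loss returned by `calm` works with a plain PyTorch loop. A minimal sketch reusing the mock tensors above; the optimizer settings and step count are placeholders, and it assumes the wrapper leaves only the newly added cross-attention parameters trainable (the anchor and augment LLMs stay frozen, as in the paper).

```python
from torch.optim import AdamW

# optimize only what the CALM wrapper left trainable
# (assumed here to be just the new cross attention blocks)
optimizer = AdamW(
    [p for p in calm.parameters() if p.requires_grad],
    lr = 3e-4,
    weight_decay = 1e-2
)

# placeholder loop reusing the single mock batch from above - swap in a real dataloader
for _ in range(100):
    loss = calm(
        seq,
        mask = mask,
        prompt = prompt
    )

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```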
To use the handy `FineTuner` trainer class, built on 🤗 Accelerate, just import `FineTuner` and use it as follows. It expects a `dataset` that yields the keyword arguments `calm` takes; a sketch of such a dataset is shown first, followed by the trainer setup.
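For illustration, a minimal sketch of a compatible dataset; the class name, lengths, and random tensors below are made up, and any map-style dataset returning this dictionary (or a tuple together with `data_kwargs`) will do.

```python
import torch
from torch.utils.data import Dataset

class MockCALMDataset(Dataset):
    # hypothetical example dataset - each item is the dict of kwargs calm's forward expects
    def __len__(self):
        return 10_000

    def __getitem__(self, idx):
        return dict(
            seq = torch.randint(0, 20000, (1024,)),   # anchor sequence (batch dim is added by the dataloader)
            mask = torch.ones((1024,)).bool(),        # attention mask for the anchor sequence
            prompt = torch.randint(0, 20000, (256,))  # prompt fed to the augmentation LLM
        )

dataset = MockCALMDataset()
```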
```python
trainer = FineTuner(
calm = calm,
dataset = dataset, # returns a dictionary of input kwargs to calm - dict(seq: Tensor, mask: Tensor, prompt: Tensor). it can also return a tuple, in which case data_kwargs needs to be set to the correspondingly ordered kwarg names
batch_size = 16,
num_train_steps = 10000,
learning_rate = 3e-4,
weight_decay = 1e-2,
warmup_steps = 1000,
checkpoint_every = 1000
)

trainer()

# checkpoints of the cross attention parameters will be saved to ./checkpoints every 1000 steps
```

To explore multiple augmentation LLMs, simply pass in a list of `AugmentParams` for `augment_llms`
ex.
```python
calm = CALM(
anchor_llm = anchor_llm,
augment_llms = [AugmentParams(augment_llm1), AugmentParams(augment_llm2)] # pass in a list of AugmentParams, each wrapping a model and the hparams specific to that transformer
)
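
# finetuning then takes one prompt per augmentation LLM, passed as a tuple -
# a sketch mirroring the multimodal example further below
# (seq and mask as in the first example; prompt order assumed to follow the order of augment_llms)

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()

prompts = (
    torch.randint(0, 20000, (1, 256)),  # prompt for augment_llm1
    torch.randint(0, 20000, (1, 256))   # prompt for augment_llm2
)

loss = calm(seq, mask = mask, prompt = prompts)
loss.backward()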
```

To explore different connectivity patterns between the anchor and the augmentation model(s), pass the connections in as a tuple of (augment layer, anchor layer) integer pairs, specifying which augmentation layer is attended to by which anchor layer.
```python
calm = CALM(
anchor_llm = anchor_llm,
augment_llms = (
AugmentParams(
model = augment_llm1,
connections = (
(1, 12), # 1st layer of augment llm1 attended to by 12th layer of anchor llm
(2, 12),
(3, 12),
(4, 12),
),
),
AugmentParams(
model = augment_llm2,
connections = (
(6, 1), # 6th layer of augment llm2 attended to by 1st layer of anchor llm
(6, 2),
(12, 12),
)
)
)
)
```

CALM setup with 2 specialized augmentation LLMs + a vision transformer
```python
import torch

# pip install vit-pytorch x-transformers

from vit_pytorch.vit import ViT, Attention
from x_transformers import TransformerWrapper, Encoder, Decoder

anchor_llm = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)

augment_llm1 = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Encoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)

augment_llm2 = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Encoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)

vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 256,
depth = 6,
heads = 16,
mlp_dim = 2048
)

# calm

from CALM_pytorch import CALM, AugmentParams, FineTuner

calm = CALM(
anchor_llm = anchor_llm,
augment_llms = (
AugmentParams(
model = augment_llm1,
mask_kwarg = 'mask'
),
AugmentParams(
model = augment_llm2,
mask_kwarg = 'mask'
),
AugmentParams(
model = vit,
input_shape = (3, 256, 256),
hidden_position = 'input',
extract_blocks_fn = lambda vit: [m for m in vit.modules() if isinstance(m, Attention)]
)
),
attn_kwargs = dict(
linear_project_context = True,
pre_rmsnorm = True,
flash = True
)
)

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()

prompt = (
torch.randint(0, 20000, (1, 256)),
torch.randint(0, 20000, (1, 256)),
torch.randn(1, 3, 256, 256)
)

loss = calm(
seq,
mask = mask,
prompt = prompt
)

loss.backward()
```

## Todo
- [x] figure out how to correctly mask augment llm tokens
- [x] auto-derive model dimensions with dummy input
- [x] take care of finetuning training logic
- [x] show example of manual definitions of custom connectivity between 2+ attention networks
- [x] if anchor and augment transformer block modules are directly passed in (without extraction fn), run a dummy input through both networks and order them correctly using hooks
- [x] fix example for x-transformers, as in x-transformers, depth is actually depth x 2, taking hiddens from after attention and ff
- [x] when finely specifying hidden positions, make sure to reorder it if the transformer blocks themselves were passed in and not ordered to begin with
- [x] extend to a list of augmentation llms
- [x] full connectivity customization
- [x] custom number of augmentation layers per augmentation llm
- [x] make simple vit work
- [x] refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}} - use dataclasses
- [x] show example
- [x] take care of caching the augment hiddens when sampling. forget about anchor kv cache for now
- [x] logic for not releasing the saved output from recorder, for inference
- [x] managing cross attention block state for popping the saved output from the recorder
- [x] move the augmentation forwards into one shared method, and craft out sampling method for anchor
- [ ] able to wire up with just module names
- [ ] show an example with giving the LLM ability to hear as well, using hubert or wav2vec wrappers
- [ ] handle a wrapper or function that takes in the sequence and prompt length, and auto derives the inputs to CALM
- [ ] add an option for a self attention pathway with memory tokens attending to hidden states of all augmentation llms, akin to what was done with Zorro

## Citations
```bibtex
@inproceedings{Bansal2024LLMAL,
title = {LLM Augmented LLMs: Expanding Capabilities through Composition},
author = {Rachit Bansal and Bidisha Samanta and Siddharth Dalmia and Nitish Gupta and Shikhar Vashishth and Sriram Ganapathy and Abhishek Bapna and Prateek Jain and Partha Pratim Talukdar},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:266755751}
}
```