https://github.com/kyegomez/flashattention20
Get down and dirty with FlashAttention2.0 in pytorch, plug in and play no complex CUDA kernels
- Host: GitHub
- URL: https://github.com/kyegomez/flashattention20
- Owner: kyegomez
- License: mit
- Created: 2023-07-19T23:36:48.000Z (about 2 years ago)
- Default Branch: main
- Last Pushed: 2023-07-31T21:34:00.000Z (about 2 years ago)
- Last Synced: 2025-04-19T20:17:08.414Z (6 months ago)
- Language: Python
- Size: 22.5 KB
- Stars: 102
- Watchers: 2
- Forks: 6
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# FlashAttention2.0: A PyTorch Implementation
FlashAttention2.0 is a PyTorch implementation of Flash Attention, a memory-efficient and highly parallelizable attention algorithm. This repository provides the code for the Flash Attention module and includes options for parallelization and mixed precision training.
## Installation
To install FlashAttention, you can clone this repository using git:
```bash
git clone https://github.com/kyegomez/FlashAttention2.0.git
cd FlashAttention2.0
```
Then, you can install the required packages using pip:
```bash
pip install -r requirements.txt
```

## Usage
Here is a basic example of how to use the FlashAttention module:
```python
import torch
from attention import FlashAttention

# Initialize a FlashAttention module
attention = FlashAttention(dim=512, heads=8, dim_head=64)

# Create some random data
x = torch.randn(1, 1000, 512)

# Apply the attention module
out = attention(x)
print(out.shape)  # Outputs: torch.Size([1, 1000, 512])
```
You can also enable parallelization and mixed precision training by setting the `parallel` and `mixed_precision` parameters to `True`:
```python
# Initialize a FlashAttention module with parallelization and mixed precision
attention = FlashAttention(dim=512, heads=8, dim_head=64, parallel=True, mixed_precision=True)

# The rest of the code is the same as before
```
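The `forward` method additionally accepts an optional `context` tensor for cross-attention, a boolean key mask, and per-call bucket sizes (see the code at the end of this README). The snippet below is an illustrative sketch based on that signature; the shapes and bucket sizes are arbitrary values chosen for the example, not values documented by the repository:
```python
import torch
from attention import FlashAttention

attention = FlashAttention(dim=512, heads=8, dim_head=64)

x = torch.randn(1, 1000, 512)        # queries are projected from x
context = torch.randn(1, 256, 512)   # keys/values are projected from context
mask = torch.ones(1, 256).bool()     # True = attend to this key position

# Bucket sizes can be overridden per call
out = attention(x, context=context, mask=mask, q_bucket_size=256, k_bucket_size=128)
print(out.shape)  # torch.Size([1, 1000, 512])
```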
## Tests
We have an extensive testing suite in `test.py`; run it for full coverage. Below are some of the tests that verify the correctness of the forward and backward passes:
```python
import torch
from flashattention import FlashAttention

def test_forward():
    attention = FlashAttention(dim=512, heads=8, dim_head=64)
    x = torch.randn(1, 1000, 512)
    out = attention(x)
    assert out.shape == (1, 1000, 512), f'Unexpected output shape: {out.shape}'

def test_backward():
    attention = FlashAttention(dim=512, heads=8, dim_head=64)
    x = torch.randn(1, 1000, 512, requires_grad=True)
    out = attention(x)
    out.sum().backward()
    assert x.grad is not None, 'No gradient computed'

test_forward()
test_backward()
```
These tests check that the output of the forward pass has the correct shape and that the backward pass correctly computes gradients.
# Contributing
We welcome contributions to the FlashAttention project! Whether you're interested in improving the code, optimizing the implementation, or adding new features, there are many ways to make a valuable contribution.
## How to Contribute
1. **Fork the repository**: Click the 'Fork' button at the top-right of this page to create your own copy of the repository.
2. **Clone your fork**: Clone your forked repository to your local machine. You can do this with the command `git clone https://github.com/yourusername/flashattention.git`.
3. **Create a new branch**: Create a new branch for your changes with the command `git checkout -b your-branch-name`.
4. **Make your changes**: Make your changes to the code. Please try to follow the existing coding style.
5. **Commit your changes**: Commit your changes with the command `git commit -m "Your commit message"`.
6. **Push your changes**: Push your changes to your forked repository with the command `git push origin your-branch-name`.
7. **Create a pull request**: Go to the [original FlashAttention repository](https://github.com/kyegomez/flashattention20) and click the 'New pull request' button. Select your forked repository and the branch you created, then click 'Create pull request'.
## Potential Optimizations
There are several areas where the FlashAttention implementation could potentially be optimized:
- **Memory usage**: The current implementation is already quite memory-efficient, but there may be ways to further reduce memory usage.
- **Speed**: The speed of the forward and backward passes could potentially be improved. This could involve optimizing the existing code or implementing new, faster algorithms.
- **Scalability**: The current implementation scales well to large input sizes, but there may be ways to improve scalability further.
- **Precision**: The implementation currently supports mixed precision training, but there may be ways to improve the precision of the computations.
## Metrics
When optimizing the FlashAttention implementation, we should aim to minimize the following metrics:
- **Memory usage**: The amount of memory used by the implementation.
- **Execution time**: The time taken to execute the forward and backward passes.
- **Error rate**: The rate of errors in the output of the attention module.
We look forward to your contributions!
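As a rough starting point for tracking the first two metrics, a micro-benchmark along the following lines could be used. This is a sketch built on standard PyTorch utilities, assuming a CUDA device and the `attention` module from the usage example; it is not part of the repository's test suite, and the sequence length and module parameters are arbitrary:
```python
import time
import torch
from attention import FlashAttention

def benchmark(seq_len=1024, dim=512, device="cuda"):
    # Build the module and a random input with gradients enabled
    attention = FlashAttention(dim=dim, heads=8, dim_head=64).to(device)
    x = torch.randn(1, seq_len, dim, device=device, requires_grad=True)

    # Measure peak memory and wall-clock time for a forward + backward pass
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    start = time.perf_counter()

    out = attention(x)
    out.sum().backward()

    torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    peak_mem = torch.cuda.max_memory_allocated(device) / 2**20

    print(f"seq_len={seq_len}: {elapsed * 1000:.1f} ms, peak memory {peak_mem:.1f} MiB")

if __name__ == "__main__":
    benchmark()
```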
# Code:
```python
import math
import torch
from functools import partial
from torch import nn, einsum
from torch.autograd.function import Function

from einops import rearrange

from torch.jit import fork, wait
from torch.cuda.amp import autocast, GradScaler
from torch.nn import DataParallel

# constants

EPSILON = 1e-10

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# flash attention forwards and backwards

# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 1 in the v2 paper """

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)

        scale = (q.shape[-1] ** -0.5)

        num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
        num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)

        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b n -> b 1 1 n')

        if not exists(mask):
            col_masks = (None,) * num_col_tiles
            mask = (col_masks,) * num_row_tiles
        else:
            mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim = -2)
            mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim = -1) for row_mask in mask)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
            all_row_maxes.split(q_bucket_size, dim = -2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(col_mask):
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_weights = torch.exp(attn_weights - new_row_maxes)

                if exists(col_mask):
                    exp_weights.masked_fill_(~col_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + block_row_sums

                oc.mul_(exp_row_max_diff).add_(exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

            oc.div_(row_sums)

        lse = all_row_sums.log() + all_row_maxes

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """ Algorithm 2 in the v2 paper """

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            lse.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                p = torch.exp(attn_weights - lsec)

                if exists(col_mask):
                    p.masked_fill_(~col_mask, 0.)

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None

# main class

# just flash attention in plain pytorch
# it will be way slower than implementing it in CUDA
# for tinkering and educational purposes

class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024,
        parallel = False,
        mixed_precision = False
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.parallel = parallel
        self.mixed_precision = mixed_precision

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory efficient attention related parameters
        # can be overriden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

        if self.parallel:
            self.model = DataParallel(self)
        if self.mixed_precision:
            self.scaler = GradScaler()

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        if self.parallel:
            # Split the input data into chunks and move each chunk to the correct GPU
            num_gpus = torch.cuda.device_count()
            x_chunks = x.split(x.size(0) // num_gpus)
            x_chunks = [chunk.to(f'cuda:{i}') for i, chunk in enumerate(x_chunks)]
            q = x_chunks

        if self.mixed_precision:
            # Use autocast to allow operations to run in lower precision
            with autocast():
                out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)
        else:
            out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```
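Beyond the shape and gradient tests above, one way to check the error rate is to compare the tiled function against plain softmax attention. The following is an illustrative sketch, not taken from the repository; it assumes `FlashAttentionFunction` is importable from the same `attention` module as `FlashAttention`, which is how the code above is organized:
```python
import torch
from attention import FlashAttentionFunction

b, h, n, d = 1, 8, 256, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# Tiled flash attention: no mask, non-causal, small bucket sizes to force several tiles
flash_out = FlashAttentionFunction.apply(q, k, v, None, False, 64, 128)

# Plain softmax attention reference
scale = d ** -0.5
attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
ref_out = attn @ v

# Should agree up to floating-point tolerance
print(torch.allclose(flash_out, ref_out, atol=1e-5))
```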