https://github.com/erfanzar/jax-flash-attn2
A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
- Host: GitHub
- URL: https://github.com/erfanzar/jax-flash-attn2
- Owner: erfanzar
- License: apache-2.0
- Created: 2024-10-23T18:08:50.000Z (12 months ago)
- Default Branch: main
- Last Pushed: 2025-03-04T10:55:23.000Z (7 months ago)
- Last Synced: 2025-04-06T06:06:07.223Z (6 months ago)
- Topics: flash-attention, flash-attention-2, jax, pallas
- Language: Python
- Homepage:
- Size: 6.19 MB
- Stars: 23
- Watchers: 1
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
# JAX-Flash-Attention2
A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
## Installation
```bash
pip install jax-flash-attn2
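# Optional sanity check (a suggestion, not from the upstream README):
# confirm the package imports in your environment.
python -c "import jax_flash_attn2"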
```

## Basic Usage
```python
import jax
import jax.numpy as jnp
import jax_flash_attn2 as jfa

# Initialize the FlashAttention module with desired configuration
flash_attention = jfa.FlashAttention(
    jfa.AttentionConfig(
        platform=jfa.Platform.TRITON,  # Options: TRITON, PALLAS, JAX
        backend=jfa.Backend.GPU,  # Options: GPU, TPU, CPU
    )
)

# Create sample inputs
batch_size, num_heads, seq_len, head_dim = 2, 4, 512, 64
# Note: the query carries 4x as many heads as key/value here (a grouped-query layout)
query = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_heads * 4, seq_len, head_dim), dtype=jnp.float16)
key = jax.random.normal(jax.random.PRNGKey(1), (batch_size, num_heads, seq_len, head_dim), dtype=jnp.float16)
value = jax.random.normal(jax.random.PRNGKey(2), (batch_size, num_heads, seq_len, head_dim), dtype=jnp.float16)

# Compute attention
output = flash_attention(
    query=query,
    key=key,
    value=value,
    causal=True  # Enable causal masking for decoder-only models
)

# output shape matches the query: (batch_size, num_heads * 4, seq_len, head_dim)
```

## Advanced Usage
### With Attention Mask
```python
# Create an attention mask (1 = attend, 0 = mask)
attention_mask = jnp.ones((batch_size, 1, seq_len, seq_len)) # Allow full attention
# For example, mask the first 100 tokens from attending to the last 100 tokens
attention_mask = attention_mask.at[:, :, :100, -100:].set(0)

output = flash_attention(
    query=query,
    key=key,
    value=value,
    attention_mask=attention_mask,
    causal=False  # Using explicit mask instead of causal
)
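
# Hedged sketch (an assumption, not from the upstream README): derive a padding mask
# from per-sequence valid lengths and broadcast it to (batch, 1, q_len, kv_len).
lengths = jnp.array([512, 300])  # number of valid tokens in each sequence
key_is_valid = jnp.arange(seq_len)[None, :] < lengths[:, None]  # (batch, kv_len)
padding_mask = jnp.broadcast_to(
    key_is_valid[:, None, None, :], (batch_size, 1, seq_len, seq_len)
).astype(jnp.float32)  # 1 = attend, 0 = mask, same convention as above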
```

### With Attention Bias
```python
# Create an attention bias
bias = jnp.zeros((batch_size, 1, seq_len, seq_len))
# Add a position-dependent bias (stronger for nearby positions), computed with
# broadcasting instead of a per-element Python loop
positions = jnp.arange(seq_len)
bias = bias + 1.0 / (1.0 + jnp.abs(positions[:, None] - positions[None, :]))

output = flash_attention(
    query=query,
    key=key,
    value=value,
    bias=bias
)
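
# Hedged sketch (an assumption, not from the upstream README): a causal mask can also
# be expressed as an additive bias, 0 on and below the diagonal and a large negative
# value above it, playing the same role as causal=True.
neg = jnp.finfo(bias.dtype).min
causal_bias = jnp.where(jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool)), 0.0, neg)
causal_bias = jnp.broadcast_to(causal_bias, (batch_size, 1, seq_len, seq_len))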
```

### With Dropout
```python
output = flash_attention(
    query=query,
    key=key,
    value=value,
    dropout_prob=0.1,
    dropout_seed=42,
    causal=True
)
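
# For deterministic evaluation, a common pattern (an assumption, not documented in
# this README) is simply to disable dropout by passing dropout_prob=0.0.
eval_output = flash_attention(
    query=query,
    key=key,
    value=value,
    dropout_prob=0.0,
    causal=True
)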
```

## Flax Modules with JFA2
Here's an example of integrating jax-flash-attn2 within a Transformer model implemented in Flax:
```python
import typing as tp
from functools import partial

import chex
import flax.nnx as nn
import jax
import jax.numpy as jnp

import jax_flash_attn2 as jfa

class JFAttention2(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        head_dim: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs = None,
    ):
        if rngs is None:
            rngs = nn.Rngs(0)
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        if self.num_key_value_groups == 1:
            assert num_attention_heads == num_key_value_heads

        linear_class = partial(
            nn.Linear,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(0.02),
            precision=precision,
            rngs=rngs,
        )
        self.q_proj = linear_class(hidden_size, num_attention_heads * self.head_dim)
        self.k_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim)
        self.v_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim)
        self.o_proj = linear_class(num_attention_heads * self.head_dim, hidden_size)

        config = jfa.AttentionConfig(platform=jfa.Platform.TRITON, backend=jfa.Backend.GPU)
        self.jfa2 = jfa.FlashAttention(config)

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: chex.Array,
        causal: bool = True,
    ) -> chex.Array:
        batch_size, sequence_length = hidden_states.shape[:2]
        query_states, key_states, value_states = (
            self.q_proj(hidden_states),
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )
        qshape = (
            batch_size,
            sequence_length,
            self.num_attention_heads,
            self.head_dim,
        )
        kv_shape = (
            batch_size,
            sequence_length,
            self.num_key_value_heads,
            self.head_dim,
        )
        query_states = query_states.reshape(qshape)
        key_states = key_states.reshape(kv_shape)
        value_states = value_states.reshape(kv_shape)
        # Turn the boolean mask into an additive bias (0 where attended, large negative
        # where masked) and run flash attention in bfloat16.
        attn_output = self.jfa2.forward(
            query_states.astype(jnp.bfloat16),
            key_states.astype(jnp.bfloat16),
            value_states.astype(jnp.bfloat16),
            jnp.where(attention_mask, 0, jnp.finfo(query_states.dtype).min).astype(jnp.bfloat16),
            causal=causal,
        )
        attn_output = jnp.reshape(attn_output, (batch_size, sequence_length, -1))
        attn_output = self.o_proj(attn_output)
        return attn_output
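
# Hedged usage sketch (not part of the upstream README): instantiate the module above
# and run a forward pass. Shapes and sizes are illustrative; the TRITON/GPU config set
# in __init__ still requires an NVIDIA GPU at runtime.
attn_layer = JFAttention2(
    hidden_size=256,
    head_dim=64,
    num_attention_heads=8,
    num_key_value_heads=2,
    rngs=nn.Rngs(0),
)
hidden_states = jnp.zeros((1, 128, 256), dtype=jnp.float32)
attention_mask = jnp.ones((1, 1, 128, 128), dtype=bool)
attn_out = attn_layer(hidden_states, attention_mask)  # (1, 128, 256)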
```

## Platform-Specific Examples
### Using JAX Backend
```python
jax_flash_attn = jfa.FlashAttention(
    jfa.AttentionConfig(
        platform=jfa.Platform.JAX,
        backend=jfa.Backend.CPU,  # Works on any hardware
    )
)

output = jax_flash_attn(query, key, value)
```

### Using Pallas for TPU
```python
tpu_flash_attn = jfa.FlashAttention(
    jfa.AttentionConfig(
        platform=jfa.Platform.PALLAS,
        backend=jfa.Backend.TPU,
    )
)

output = tpu_flash_attn(query, key, value)
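
# Hedged sketch (an assumption, not from the upstream README): choose a platform/backend
# pair from jax.default_backend() so the same script runs on GPU, TPU, or CPU hosts.
detected = jax.default_backend()  # "gpu", "tpu", or "cpu"
if detected == "gpu":
    auto_config = jfa.AttentionConfig(platform=jfa.Platform.TRITON, backend=jfa.Backend.GPU)
elif detected == "tpu":
    auto_config = jfa.AttentionConfig(platform=jfa.Platform.PALLAS, backend=jfa.Backend.TPU)
else:
    auto_config = jfa.AttentionConfig(platform=jfa.Platform.JAX, backend=jfa.Backend.CPU)
auto_flash_attn = jfa.FlashAttention(auto_config)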
```

## Integration with JAX Transformations
```python
@jax.jit
def attention_forward(q, k, v, mask=None):
    return flash_attention(
        query=q,
        key=k,
        value=v,
        attention_mask=mask,
        causal=True
    )

# Call the JIT-compiled function
fast_attention = attention_forward(query, key, value)

# With gradient computation
def loss_fn(q, k, v):
    attn_output = flash_attention(q, k, v, causal=True)
    return jnp.mean(attn_output)

grads = jax.grad(loss_fn)(query, key, value)
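
# Transformations compose; a hedged extension of the example above (not from the
# upstream README): jax.value_and_grad returns the loss alongside the gradients,
# and the whole thing can be wrapped in jax.jit.
loss, grads = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1, 2)))(query, key, value)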
```

## Limitations
- Triton platform is only available on NVIDIA GPUs.
- Some platform and backend combinations are not supported.
- Custom attention masks are not yet supported (use bias instead).

## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## Citation
If you use this implementation in your research, please cite:
```bibtex
@software{jax_flash_attn2,
title = {JAX Flash Attention 2.0},
year = {2024},
url = {https://github.com/erfanzar/jax-flash-attn2}
}
```

### Reference Citations
```bibtex
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2022}
}
@inproceedings{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}
```

## Acknowledgments and References
1. All of the kernels are copied from [`EasyDeL`](https://github.com/erfanzar/Easydel).