https://github.com/juliagenai/flashattentionwrapper.jl
Just a simple wrapper for the Flash Attention operation.
- Host: GitHub
- URL: https://github.com/juliagenai/flashattentionwrapper.jl
- Owner: JuliaGenAI
- License: mit
- Created: 2024-12-06T12:04:26.000Z (6 months ago)
- Default Branch: main
- Last Pushed: 2024-12-28T12:42:32.000Z (5 months ago)
- Last Synced: 2025-04-06T06:47:33.671Z (about 2 months ago)
- Language: Julia
- Homepage:
- Size: 13.7 KB
- Stars: 3
- Watchers: 4
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# FlashAttentionWrapper.jl
Just a simple wrapper for the [Flash Attention](https://github.com/Dao-AILab/flash-attention) operation.
## Installation
```julia
using FlashAttentionWrapper

FlashAttentionWrapper.install()
```

Note that by default it will install the latest version of FlashAttention.
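FlashAttention only runs on NVIDIA GPUs, so it can be worth confirming that CUDA.jl sees a working device before installing. A minimal sketch using plain CUDA.jl calls (not part of this package's API):

```julia
using CUDA

# FlashAttention requires an NVIDIA GPU; warn early if CUDA.jl
# cannot find a functional device.
if CUDA.functional()
    @info "CUDA is available" device = CUDA.device()
else
    @warn "CUDA is not functional; the FlashAttention kernels will not be usable"
end
```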
## Example
```julia
using FlashAttentionWrapper

# q, k, v are assumed to be 4D CuArrays of size (head_dim, n_heads, seq_len, batch_size)
o = mha(q, k, v; kw...)
```

Check the original [doc](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#how-to-use-flashattention) for an explanation of the supported keyword arguments.
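For reference, a minimal end-to-end forward pass might look like the sketch below. The shapes and `Float16` element type follow the conventions used elsewhere in this README; the `causal` keyword is an assumption that keyword arguments are forwarded to the upstream flash-attention kernel, so check its docs for what is actually supported:

```julia
using CUDA
using FlashAttentionWrapper

head_dim, n_heads, seq_len, batch_size = 64, 8, 1024, 2

# q, k, v follow the (head_dim, n_heads, seq_len, batch_size) layout described above.
q = CUDA.randn(Float16, head_dim, n_heads, seq_len, batch_size)
k = CUDA.randn(Float16, head_dim, n_heads, seq_len, batch_size)
v = CUDA.randn(Float16, head_dim, n_heads, seq_len, batch_size)

o = mha(q, k, v)                      # output has the same layout as q
o_causal = mha(q, k, v; causal=true)  # assumes `causal` is forwarded upstream
```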
Backward is also supported:
```julia
using CUDA
using Zygote

o, back = Zygote.pullback(q, k, v) do q, k, v
    mha(q, k, v)
end

Δo = CUDA.randn(eltype(o), size(o))
Δq, Δk, Δv = back(Δo)
```
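If all you need is the gradient of a scalar loss, `Zygote.gradient` works just as well. A sketch, with `sum` standing in for a real loss:

```julia
using Zygote

# Gradients of a placeholder scalar loss w.r.t. q, k and v.
Δq, Δk, Δv = Zygote.gradient((q, k, v) -> sum(mha(q, k, v)), q, k, v)
```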
If you'd like to use it with Lux.jl, here's a handy example:
```julia
using CUDA
using Lux

head_dim, n_head, seq_len, batch_size = 256, 8, 1024, 4
hidden_dim = head_dim * n_head

x = CUDA.randn(Float16, (hidden_dim, seq_len, batch_size))
m = Chain(
BranchLayer(
Chain(
Dense(hidden_dim => hidden_dim, use_bias=false),
ReshapeLayer((head_dim, n_head, seq_len))
),
Chain(
Dense(hidden_dim => hidden_dim, use_bias=false),
ReshapeLayer((head_dim, n_head, seq_len))
),
Chain(
Dense(hidden_dim => hidden_dim, use_bias=false),
ReshapeLayer((head_dim, n_head, seq_len))
),
),
Attention(),
ReshapeLayer((hidden_dim, seq_len)),
Dense(hidden_dim => hidden_dim, use_bias=false),
)

using Random
rng = Random.default_rng()
ps, st = LuxCore.setup(rng, m)
cu_ps = recursive_map(CuArray{Float16}, ps)

o, _ = m(x, cu_ps, st)
```

Or if you prefer Flux.jl:
```julia
using CUDA
using Flux

head_dim, n_head, seq_len, batch_size = 256, 8, 1024, 4
hidden_dim = head_dim * n_head

x = CUDA.randn(Float16, (hidden_dim, seq_len, batch_size))
m = Flux.Chain(
Flux.Parallel(
tuple,
Flux.Chain(
Flux.Dense(CUDA.randn(Float16, hidden_dim, hidden_dim), false),
x -> reshape(x, head_dim, n_head, seq_len, batch_size),
),
Flux.Chain(
Flux.Dense(CUDA.randn(Float16, hidden_dim, hidden_dim), false),
x -> reshape(x, head_dim, n_head, seq_len, batch_size),
),
Flux.Chain(
Flux.Dense(CUDA.randn(Float16, hidden_dim, hidden_dim), false),
x -> reshape(x, head_dim, n_head, seq_len, batch_size),
),
),
qkv -> reshape(mha(qkv...;), :, seq_len, batch_size),
Flux.Dense(CUDA.randn(Float16, hidden_dim, hidden_dim), false),
)

m(x)
```
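Gradients flow through the whole Flux chain as well. A minimal sketch using Flux's explicit-mode `gradient`, again with `sum` as a placeholder loss:

```julia
# Explicit-mode gradient of the Flux model; replace `sum` with a real loss.
grads = Flux.gradient(m -> sum(m(x)), m)
```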
## TODO List
- [ ] Add benchmark
- [ ] Support FlexAttention?
- [ ] Support [FlashInfer](https://github.com/flashinfer-ai/flashinfer)?