https://github.com/pxl-th/nnop.jl

Pure Julia NN kernels.

# NNop.jl

Pure Julia NN kernels.

> [!WARNING]
> The package is at an early stage of development and is not yet fully ready for use.

## Ops

### Flash Attention

Implementation of [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135).

```julia
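using AMDGPU, NNop  # ROCArray comes from AMDGPU.jl

E = 64       # head (embedding) size
L = 4096     # sequence length
H, B = 4, 4  # number of heads, batch size

# Inputs are laid out as (E, L, H, B).
q = ROCArray(rand(Float32, E, L, H, B))
k = ROCArray(rand(Float32, E, L, H, B))
v = ROCArray(rand(Float32, E, L, H, B))

o = NNop.flash_attention(q, k, v)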
```
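
The core idea of the paper can be illustrated on the CPU with plain arrays. The sketch below is not NNop's kernel: `naive_attention` and `tiled_attention` are hypothetical helper names, it handles a single head with `(E, L)` matrices, it tiles only over keys, and it assumes the usual `1/sqrt(E)` scaling, which this README does not state. It shows how a running maximum and a running sum per query let the softmax be accumulated block by block, so the full `L×L` score matrix is never stored.

```julia
# Reference version: materializes the full (L, L) score matrix.
function naive_attention(q::AbstractMatrix, k::AbstractMatrix, v::AbstractMatrix)
    T = eltype(q)
    scale = inv(sqrt(T(size(q, 1))))   # assumed 1/sqrt(E) scaling
    s = (k' * q) .* scale              # s[j, i] = k_j ⋅ q_i
    s = s .- maximum(s; dims = 1)      # subtract per-query max for stability
    p = exp.(s)
    p = p ./ sum(p; dims = 1)          # softmax over keys (dimension 1)
    return v * p                       # (E, L)
end

# Tiled version: processes `block` keys at a time and keeps only
# per-query running statistics (m = max, l = sum) between blocks.
function tiled_attention(q::AbstractMatrix, k::AbstractMatrix, v::AbstractMatrix;
                         block::Int = 32)
    E, L = size(q)
    T = eltype(q)
    scale = inv(sqrt(T(E)))
    o = zeros(T, E, L)         # unnormalized output accumulator
    m = fill(typemin(T), L)    # running max per query, starts at -Inf
    l = zeros(T, L)            # running softmax denominator per query
    for j0 in 1:block:L
        js = j0:min(j0 + block - 1, L)
        s = (k[:, js]' * q) .* scale               # (length(js), L) scores
        m_new = max.(m, vec(maximum(s; dims = 1)))
        α = exp.(m .- m_new)                       # rescales the old accumulators
        p = exp.(s .- m_new')
        l .= l .* α .+ vec(sum(p; dims = 1))
        o .= o .* α' .+ v[:, js] * p
        m .= m_new
    end
    return o ./ l'             # normalize once at the end
end

# Both versions should agree up to floating-point error.
q, k, v = (rand(Float32, 64, 256) for _ in 1:3)
@assert tiled_attention(q, k, v; block = 32) ≈ naive_attention(q, k, v)
```

The temporary `s` here has `block × L` entries instead of `L × L`; a GPU kernel would typically also tile over queries so that each tile fits in fast on-chip memory.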

#### Benchmarks:

Measured for the problem size `(E=64, L=4096, H=4, B=4)`:

||Naïve Attention|Flash Attention|
|-|-|-|
|Execution time|55.034 ms|18.490 ms|
|Peak memory usage|4.044 GiB|16.500 MiB|
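
A back-of-the-envelope calculation helps put the memory figures in context. The breakdown below is an assumption, not something the benchmark reports: it simply computes the size of the arrays involved for this problem size in `Float32`. `to_mib` is a hypothetical helper.

```julia
E, L, H, B = 64, 4096, 4, 4
bytes = sizeof(Float32)                 # 4 bytes per element

scores_bytes = L * L * H * B * bytes    # full (L, L, H, B) attention matrix
output_bytes = E * L * H * B * bytes    # (E, L, H, B) attention output
stats_bytes  = L * H * B * bytes        # one per-query statistic (e.g. a running max or sum)

to_mib(n) = n / 2^20

to_mib(scores_bytes)                    # 1024.0 (1 GiB per score-sized array)
to_mib(output_bytes)                    # 16.0
to_mib(output_bytes + 2 * stats_bytes)  # 16.5
```

Under these assumptions, the 16.500 MiB figure is consistent with storing only the output plus two per-query statistic vectors, while the naïve figure of 4.044 GiB is slightly more than four score-sized arrays.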

#### Features:

- Forward & backward passes.
- Arbitrary sequence length.
- Arbitrary head sizes.
- FP32, FP16, BF16 support.

In progress:

- [ ] Causal masking.
- [ ] Variable sequence length.

### Fused (online) Softmax

Implementation of [Online normalizer calculation for softmax](https://arxiv.org/abs/1805.02867).

```julia
using AMDGPU, NNop  # ROCArray comes from AMDGPU.jl

x = ROCArray(ones(Float32, 8192, 1024))
y = NNop.online_softmax(x)
```

||Naïve Softmax|Online Softmax|
|-|-|-|
|Execution time|745.123 μs|61.600 μs|
|Peak memory usage|64.258 MiB|32.000 MiB|
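
The online normalizer from the paper can be sketched on the CPU for a single vector. This is an illustration of the algorithm, not NNop's GPU kernel: `online_softmax_sketch` and `naive_softmax` are hypothetical names, and the sketch assumes finite inputs. The maximum and the normalizer are computed together in one pass, rescaling the running sum whenever a larger maximum is found.

```julia
# Naïve softmax: one pass for the max, one for the sum, one to normalize.
function naive_softmax(x::AbstractVector)
    m = maximum(x)
    e = exp.(x .- m)
    return e ./ sum(e)
end

# Online softmax: max and normalizer in a single pass, then one pass to normalize.
function online_softmax_sketch(x::AbstractVector{T}) where {T<:AbstractFloat}
    m = typemin(T)   # running maximum, starts at -Inf
    d = zero(T)      # running normalizer, expressed relative to m
    for xi in x
        m_new = max(m, xi)
        # Rescale the old sum to the new maximum, then add the new term.
        d = d * exp(m - m_new) + exp(xi - m_new)
        m = m_new
    end
    return exp.(x .- m) ./ d
end

x = randn(Float32, 1024)
@assert online_softmax_sketch(x) ≈ naive_softmax(x)
```

For scale, an `8192×1024` `Float32` array occupies exactly 32 MiB. If the benchmark used the input from the example above, the 32.000 MiB peak would correspond to allocating only the output, and the naïve figure to roughly one extra array of the same size.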