Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/MzeroMiko/mamba-mini
An efficient PyTorch implementation of selective scan in a single file. It works on both CPU and GPU and comes with the corresponding mathematical derivation. It is probably the implementation closest to selective_scan_cuda in mamba.
- Host: GitHub
- URL: https://github.com/MzeroMiko/mamba-mini
- Owner: MzeroMiko
- Created: 2024-02-05T12:24:37.000Z (11 months ago)
- Default Branch: main
- Last Pushed: 2024-03-04T10:37:07.000Z (10 months ago)
- Last Synced: 2024-08-01T04:02:11.120Z (6 months ago)
- Topics: efficient, mamba, pytorch, selective-scan
- Language: Python
- Homepage:
- Size: 1.16 MB
- Stars: 61
- Watchers: 3
- Forks: 0
- Open Issues: 4
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
- Awesome-state-space-models - Mamba-mini
README
# mamba-mini
An efficient implementation of selective scan in a single file. It works on both CPU and GPU and comes with the corresponding mathematical derivation. It is probably the implementation closest to `selective_scan_cuda` in mamba.

### update!
* **`20240304: New implementation with new derivations!`** We now support a new, chunk-parallel approach to implementing selective_scan: [`selective_scan_easyv3`](./test_selective_scan_easy.py). It is faster than `selective_scan_easy` when `d_state=1`, but still slower than the CUDA implementation in `mamba_ssm`. We plan to implement it in `triton` and benchmark it in the future.

### mathematical derivation of the `chunk-naive version`
The code is in [`selective_scan_easy`](./test_selective_scan_easy.py) and [`SelectiveScanEasy`](./test_selective_scan_easy.py).
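For reference, the recurrence that `selective_scan_chunk` in the naive code evaluates can be written out from the code itself. This is our reading of the code, and its notation may differ from the figure. Here $\Delta_t$ is `dts` (after the optional `delta_bias` and softplus), $A$ is diagonal, and the chunk starts at index $a$ with carried-in state $h_{a-1}$ (`hprefix`); $\tau_t$ is the inclusive cumulative sum `ts = dts.cumsum(dim=0)`:

$$
\begin{aligned}
h_t &= e^{\Delta_t A}\, h_{t-1} + \Delta_t B_t u_t, \qquad y_t = C_t h_t + D u_t,\\
\tau_t &= \sum_{i=a}^{t} \Delta_i,\\
h_t &= e^{A\tau_t}\, h_{a-1} + \sum_{i=a}^{t} e^{A(\tau_t-\tau_i)}\, \Delta_i B_i u_i
     = e^{A\tau_t}\Big(h_{a-1} + \sum_{i=a}^{t} e^{-A\tau_i}\, \Delta_i B_i u_i\Big).
\end{aligned}
$$

Because $A$ is diagonal, every exponential acts elementwise on the $(D, N)$ state entries. The code also divides $e^{A\tau_t}$ by $s = e^{A\tau_{\text{last}}}$ (`scale`) before the cumulative sum and multiplies it back afterwards; the factors of $s$ cancel, so $h_t$ is unchanged and only the range of the intermediate values is affected.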
![image](./assets/derivation.png)

### mathematical derivation of the `chunk-parallel version`
This is the chunk-parallel version of selective scan, with support for several different branches.
The code is in [`selective_scan_easyv3`](./test_selective_scan_easy.py).
![image](./assets/derivation_general.png)
![image](./assets/derivation_wdk.png)
![image](./assets/derivation_wdv.png)
![image](./assets/derivation_dk1.png)

### naive code
```python
import torch


def selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64):
    """
    # B: batch_size, G: groups, D: dim, N: state dim, L: seqlen
    us: B, G * D, L
    dts: B, G * D, L
    As: G * D, N
    Bs: B, G, N, L
    Cs: B, G, N, L
    Ds: G * D
    delta_bias: G * D
    # chunksize can be any value you like, but as chunksize grows, hs may become NaN,
    # because exp(sum(delta) A) gets extremely small within a long chunk
    """
    def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix):
        """
        partial(h) / partial(t) = Ah + Bu; y = Ch + Du;
        => partial(h*exp(-At)) / partial(t) = Bu*exp(-At);
        => h_t = exp(At)*h_0 + sum_{0}_{t}_{Bu*exp(A(t-v)) dv};
        => h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... + dt_i)) dt_i});
           y_i = C_i*h_i + D*u_i
        """
        """
        us, dts: (L, B, G, D) # L is chunk_size
        As: (G, D, N)
        Bs, Cs: (L, B, G, N)
        Ds: (G, D)
        hprefix: (B, G, D, N)
        """
        ts = dts.cumsum(dim=0)  # inclusive cumulative sum of delta within the chunk
        Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp()  # exp(A * ts)
        scale = Ats[-1].detach()  # normalize by the chunk's last step to limit under/overflow
        rAts = Ats / scale
        duts = dts * us
        dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs)  # delta * B * u
        hs_tmp = rAts * (dtBus / rAts).cumsum(dim=0)  # intra-chunk contribution
        hs = hs_tmp + Ats * hprefix.unsqueeze(0)  # plus the state carried in from earlier chunks
        ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs)
        return ys, hs

    inp_dtype = us.dtype
    has_D = Ds is not None
    dts = dts.float()
    if delta_bias is not None:
        dts = dts + delta_bias.view(1, -1, 1).float()
    if delta_softplus:
        dts = torch.nn.functional.softplus(dts)

    if len(Bs.shape) == 3:
        Bs = Bs.unsqueeze(1)
    if len(Cs.shape) == 3:
        Cs = Cs.unsqueeze(1)
    B, G, N, L = Bs.shape
    # move the sequence dimension first: (L, B, G, D) and (L, B, G, N)
    us = us.view(B, G, -1, L).permute(3, 0, 1, 2).float()
    dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).float()
    As = As.view(G, -1, N).float()
    Bs = Bs.permute(3, 0, 1, 2).float()
    Cs = Cs.permute(3, 0, 1, 2).float()
    Ds = Ds.view(G, -1).float() if has_D else None
    D = As.shape[1]

    oys = []
    # ohs = []
    hprefix = us.new_zeros((B, G, D, N), dtype=torch.float)
    # iterate over every chunk, including a final chunk of length 1
    for i in range(0, L, chunksize):
        ys, hs = selective_scan_chunk(
            us[i:i + chunksize], dts[i:i + chunksize],
            As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix,
        )
        oys.append(ys)
        # ohs.append(hs)
        hprefix = hs[-1]  # carry the last state into the next chunk

    oys = torch.cat(oys, dim=0)
    # ohs = torch.cat(ohs, dim=0)
    if has_D:
        oys = oys + Ds * us
    oys = oys.permute(1, 2, 3, 0).view(B, -1, L)
    oys = oys.to(inp_dtype)
    # hprefix = hprefix.to(inp_dtype)

    return oys if not return_last_state else (oys, hprefix.view(B, G * D, N))
```
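To make the chunked closed form concrete without PyTorch, the sketch below compares a plain sequential recurrence with a chunked evaluation that mirrors the steps of `selective_scan_chunk`: an inclusive cumulative sum of delta, rescaling by the chunk's last step, and a prefix state carried between chunks. It is a hypothetical pure-Python illustration for a single channel with a diagonal `A`, not code from the repository; `sequential_scan` and `chunked_scan` are names introduced here.

```python
import math


def sequential_scan(u, dt, A, B, C, D=0.0):
    """Reference recurrence for one channel with a diagonal A of size N:
    h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * u_t,  y_t = C_t . h_t + D * u_t.
    u, dt: length-L lists; A: length-N list; B, C: L lists of length N."""
    N = len(A)
    h = [0.0] * N
    ys = []
    for t in range(len(u)):
        h = [math.exp(dt[t] * A[n]) * h[n] + dt[t] * B[t][n] * u[t] for n in range(N)]
        ys.append(sum(C[t][n] * h[n] for n in range(N)) + D * u[t])
    return ys


def chunked_scan(u, dt, A, B, C, D=0.0, chunksize=4):
    """Chunked closed form: inside a chunk,
    h_t = exp(A*tau_t) * (h_prefix + sum_i exp(-A*tau_i) * dt_i * B_i * u_i),
    where tau is the inclusive cumulative sum of dt within the chunk."""
    N, L = len(A), len(u)
    hprefix = [0.0] * N  # state carried in from the previous chunk
    ys = []
    for start in range(0, L, chunksize):  # covers every timestep, including a short last chunk
        stop = min(start + chunksize, L)
        tau, acc = [], 0.0
        for t in range(start, stop):
            acc += dt[t]
            tau.append(acc)
        # normalize by the chunk's last step (like `scale`); it cancels algebraically below
        scale = [math.exp(A[n] * tau[-1]) for n in range(N)]
        running = [0.0] * N  # cumulative sum of dt * B * u / r
        h = hprefix
        for k, t in enumerate(range(start, stop)):
            r = [math.exp(A[n] * tau[k]) / scale[n] for n in range(N)]
            running = [running[n] + dt[t] * B[t][n] * u[t] / r[n] for n in range(N)]
            h = [r[n] * running[n] + math.exp(A[n] * tau[k]) * hprefix[n] for n in range(N)]
            ys.append(sum(C[t][n] * h[n] for n in range(N)) + D * u[t])
        hprefix = h
    return ys
```

With the same inputs, `chunked_scan(..., chunksize=3)` and `sequential_scan(...)` should agree up to floating-point rounding for any chunk size. Since `scale` cancels, it only changes the range of the intermediate values; with very long chunks the factors `r` can grow large enough to overflow, which is the situation the `chunksize` note in the docstring warns about.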
### to test
```bash
pytest test_selective_scan.py
```
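One edge case worth covering when testing any chunked scan is that the chunk loop visits every timestep, particularly when `L` is one more than a multiple of `chunksize`. For example, iterating `range(0, L - 1, chunksize)` with `L=65` and `chunksize=64` yields only the start index 0, so the final step is never scanned. The pytest-style check below is a hypothetical sketch; `chunk_ranges` and the test name are illustrative and not taken from the repository's test file.

```python
def chunk_ranges(L, chunksize):
    """Hypothetical helper: (start, stop) pairs that tile range(L) into chunks."""
    return [(i, min(i + chunksize, L)) for i in range(0, L, chunksize)]


def test_chunks_cover_every_timestep():
    # lengths just below, at, and just above multiples of the chunk size
    for L in (1, 2, 63, 64, 65, 129):
        for chunksize in (1, 16, 64):
            covered = [t for start, stop in chunk_ranges(L, chunksize) for t in range(start, stop)]
            assert covered == list(range(L))
```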