https://github.com/kyegomez/mambaformer
Implementation of MambaFormer in PyTorch ++ Zeta from the paper: "Can Mamba Learn How to Learn? A Comparative Study on In-Context Learning Tasks"
ai attention attention-is-all-you-need attention-mechanisms mamba ml ssms transformer
Last synced: 6 months ago
- Host: GitHub
- URL: https://github.com/kyegomez/mambaformer
- Owner: kyegomez
- License: MIT
- Created: 2024-04-03T23:57:50.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2025-03-17T18:55:17.000Z (7 months ago)
- Last Synced: 2025-03-24T11:55:24.660Z (7 months ago)
- Topics: ai, attention, attention-is-all-you-need, attention-mechanisms, mamba, ml, ssms, transformer
- Language: Python
- Homepage: https://discord.gg/7VckQVxvKk
- Size: 2.17 MB
- Stars: 20
- Watchers: 2
- Forks: 1
- Open Issues: 4
Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
README
[Discord](https://discord.gg/qUtxnK2NMf)
# MambaFormer
Implementation of MambaFormer in PyTorch ++ Zeta from the paper: "Can Mamba Learn How to Learn? A Comparative Study on In-Context Learning Tasks"

## install
`pip3 install mamba-former`

## usage
```python
import torch
from mamba_former.main import MambaFormer

# Forward pass example
x = torch.randint(1, 1000, (1, 100))  # Tokens: integers representing input data, shape (batch, seq_len)

# Model
model = MambaFormer(
    dim=512,  # Dimension of the model
    num_tokens=1000,  # Number of unique tokens in the input data
    depth=6,  # Number of layers
    d_state=512,  # Dimension of the SSM state
    d_conv=128,  # Dimension of the convolutional layer
    heads=8,  # Number of attention heads
    dim_head=64,  # Dimension of each attention head
    return_tokens=True,  # Whether to return the tokens in the output
)

# Forward pass
out = model(x)

# If training:
# out = model(x, return_loss=True)  # Forward pass that also computes the loss

# Print the output
print(out)  # Output tensor
print(out.shape)  # Shape of the output tensor
```
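
If you want to train via the `return_loss=True` path, a loop along the following lines should work. This is a minimal sketch, not taken from the repo's README: it assumes `model(x, return_loss=True)` returns a scalar loss tensor (as the commented line above suggests) and uses a standard PyTorch optimizer; adjust to the actual API if it differs.

```python
import torch
from mamba_former.main import MambaFormer

model = MambaFormer(
    dim=512,
    num_tokens=1000,
    depth=6,
    d_state=512,
    d_conv=128,
    heads=8,
    dim_head=64,
    return_tokens=True,
)

# Assumption: model(x, return_loss=True) computes a scalar loss internally
# from the token sequence (e.g. next-token prediction).
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(10):
    x = torch.randint(1, 1000, (1, 100))  # dummy (batch, seq_len) token ids
    loss = model(x, return_loss=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss {loss.item():.4f}")
```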
# License
MIT