https://github.com/kyegomez/mambaformer
Implementation of MambaFormer in PyTorch and Zeta, from the paper "Can Mamba Learn How to Learn? A Comparative Study on In-Context Learning Tasks"
- Host: GitHub
- URL: https://github.com/kyegomez/mambaformer
- Owner: kyegomez
- License: mit
- Created: 2024-04-03T23:57:50.000Z (about 2 years ago)
- Default Branch: main
- Last Pushed: 2025-03-17T18:55:17.000Z (about 1 year ago)
- Last Synced: 2025-03-24T11:55:24.660Z (about 1 year ago)
- Topics: ai, attention, attention-is-all-you-need, attention-mechanisms, mamba, ml, ssms, transformer
- Language: Python
- Homepage: https://discord.gg/7VckQVxvKk
- Size: 2.17 MB
- Stars: 20
- Watchers: 2
- Forks: 1
- Open Issues: 4
Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
README
# MambaFormer
Implementation of MambaFormer in PyTorch and Zeta, from the paper "Can Mamba Learn How to Learn? A Comparative Study on In-Context Learning Tasks"
## install
`pip3 install mamba-former`
## usage
```python
import torch
from mamba_former.main import MambaFormer

# Example input: a batch of 1 sequence of 100 integer token IDs in [1, 1000)
x = torch.randint(1, 1000, (1, 100))

# Model
model = MambaFormer(
    dim=512,  # Model (embedding) dimension
    num_tokens=1000,  # Vocabulary size: number of unique tokens
    depth=6,  # Number of MambaFormer layers
    d_state=512,  # State dimension of the Mamba (SSM) blocks
    d_conv=128,  # Convolution size of the Mamba blocks
    heads=8,  # Number of attention heads
    dim_head=64,  # Dimension of each attention head
    return_tokens=True,  # Whether to return the tokens in the output
)

# Forward pass
out = model(x)

# If training
# out = model(x, return_loss=True)  # Forward pass that also calculates the loss

# Print the output tensor and its shape
print(out)
print(out.shape)
```
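The usage example mentions a `return_loss=True` option for training but does not say how that loss is computed. As a rough illustration only, the sketch below shows one common convention for autoregressive token models: shift the sequence by one position and apply cross-entropy to the logits. The `stand_in` model is a hypothetical placeholder, not MambaFormer, and nothing here is confirmed to match the library's internals.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_tokens, dim = 1000, 64

# Hypothetical stand-in for a token-level model: maps (batch, seq) integer
# tokens to (batch, seq, num_tokens) logits. It is NOT MambaFormer.
stand_in = nn.Sequential(nn.Embedding(num_tokens, dim), nn.Linear(dim, num_tokens))

x = torch.randint(1, num_tokens, (1, 100))  # same shape as the usage example

# Shift by one position: the model sees tokens 0..98 and predicts tokens 1..99.
inputs, targets = x[:, :-1], x[:, 1:]
logits = stand_in(inputs)  # shape (1, 99, 1000)

# cross_entropy expects (N, C) logits and (N,) class indices.
loss = F.cross_entropy(logits.reshape(-1, num_tokens), targets.reshape(-1))

loss.backward()  # gradients are now available for an optimizer step
print(logits.shape, loss.item())
```

If the real model returns logits of the same `(batch, seq, num_tokens)` shape, it could in principle be dropped in where `stand_in` is used, though that is an assumption to check against the library.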
# License
MIT