https://github.com/kyegomez/jamba
PyTorch Implementation of Jamba: "Jamba: A Hybrid Transformer-Mamba Language Model"
ai artificial-neural-networks attention-is-all-you-need attention-mechanism attention-mechanisms gpt ml ssm transformers
- Host: GitHub
- URL: https://github.com/kyegomez/jamba
- Owner: kyegomez
- License: MIT
- Created: 2024-04-01T15:56:24.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2025-04-04T12:57:27.000Z (18 days ago)
- Last Synced: 2025-04-14T17:02:50.833Z (8 days ago)
- Topics: ai, artificial-neural-networks, attention-is-all-you-need, attention-mechanism, attention-mechanisms, gpt, ml, ssm, transformers
- Language: Python
- Homepage: https://discord.gg/7VckQVxvKk
- Size: 2.17 MB
- Stars: 166
- Watchers: 4
- Forks: 11
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
README
[Discord](https://discord.gg/qUtxnK2NMf)
# Jamba
PyTorch Implementation of Jamba: "Jamba: A Hybrid Transformer-Mamba Language Model"

## Install

`$ pip install jamba`

## Usage
```python
# Import torch, used here to build the input tensor
import torch

# Import the Jamba model from the jamba.model module
from jamba.model import Jamba

# Create a tensor of random token IDs in [0, 100) with shape (batch=1, seq_len=100).
# This simulates a batch of tokens to pass through the model.
x = torch.randint(0, 100, (1, 100))

# Initialize the Jamba model with the specified parameters:
# dim: model (embedding) dimension
# depth: number of layers in the model
# num_tokens: vocabulary size (number of unique tokens)
# d_state: dimension of the SSM hidden state
# d_conv: width of the convolution in the Mamba blocks
# heads: number of attention heads
# num_experts: number of expert networks in each MoE layer
# num_experts_per_token: number of experts routed to each token
model = Jamba(
    dim=512,
    depth=6,
    num_tokens=100,
    d_state=256,
    d_conv=128,
    heads=8,
    num_experts=8,
    num_experts_per_token=2,
)

# Perform a forward pass through the model with the input tokens.
# This returns the model's predictions for each position in the sequence.
output = model(x)

# Print the model's predictions
print(output)
```
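The output can then be used for next-token prediction. Below is a minimal sketch, assuming the forward pass returns per-position logits over the vocabulary with shape `(batch, seq_len, num_tokens)` (an assumption, not confirmed against the repository):

```python
import torch

# Assumed output shape: (batch, seq_len, num_tokens) logits
assert output.shape == (1, 100, 100)

# Greedy next-token prediction from the logits at the last position
probs = torch.softmax(output[:, -1, :], dim=-1)
next_token = torch.argmax(probs, dim=-1)
print(next_token)  # tensor of shape (batch,)
```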
## Train
`python3 train.py`
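For reference, here is a minimal, hypothetical training loop showing how the model could be optimized with a next-token cross-entropy objective. It is a sketch, not the repository's actual `train.py`, and it assumes the logits shape described above:

```python
import torch
import torch.nn.functional as F
from jamba.model import Jamba

# Same configuration as the usage example above
model = Jamba(
    dim=512, depth=6, num_tokens=100,
    d_state=256, d_conv=128, heads=8,
    num_experts=8, num_experts_per_token=2,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(100):
    # Random token IDs stand in for a real tokenized dataset
    tokens = torch.randint(0, 100, (1, 101))
    inputs, targets = tokens[:, :-1], tokens[:, 1:]

    # Assumed: model returns logits of shape (batch, seq_len, num_tokens)
    logits = model(inputs)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss {loss.item():.4f}")
```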
# License

MIT