An open API service indexing awesome lists of open source software.

https://github.com/kyegomez/jamba

PyTorch Implementation of Jamba: "Jamba: A Hybrid Transformer-Mamba Language Model"
https://github.com/kyegomez/jamba

ai artificial-neural-networks attention-is-all-you-need attention-mechanism attention-mechanisms gpt ml ssm transformers

Last synced: 6 months ago
JSON representation

PyTorch Implementation of Jamba: "Jamba: A Hybrid Transformer-Mamba Language Model"

Awesome Lists containing this project

README

          

[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)

# Jamba
PyTorch Implementation of Jamba: "Jamba: A Hybrid Transformer-Mamba Language Model"

## install
`$ pip install jamba`

## usage

```python
# Import the torch library, which provides tools for machine learning
import torch

# Import the Jamba model from the jamba.model module
from jamba.model import Jamba

# Create a tensor of random integers between 0 and 100, with shape (1, 100)
# This simulates a batch of tokens that we will pass through the model
x = torch.randint(0, 100, (1, 100))

# Initialize the Jamba model with the specified parameters
# dim: dimensionality of the input data
# depth: number of layers in the model
# num_tokens: number of unique tokens in the input data
# d_state: dimensionality of the hidden state in the model
# d_conv: dimensionality of the convolutional layers in the model
# heads: number of attention heads in the model
# num_experts: number of expert networks in the model
# num_experts_per_token: number of experts used for each token in the input data
model = Jamba(
dim=512,
depth=6,
num_tokens=100,
d_state=256,
d_conv=128,
heads=8,
num_experts=8,
num_experts_per_token=2,
)

# Perform a forward pass through the model with the input data
# This will return the model's predictions for each token in the input data
output = model(x)

# Print the model's predictions
print(output)

```

## Train
`python3 train.py`

# License
MIT