https://github.com/kyegomez/mlxtransformer
Simple Implementation of a Transformer in the new framework MLX by Apple
- Host: GitHub
- URL: https://github.com/kyegomez/mlxtransformer
- Owner: kyegomez
- License: mit
- Created: 2023-12-06T06:43:50.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-11-18T02:14:59.000Z (over 1 year ago)
- Last Synced: 2025-04-19T20:17:01.847Z (about 1 year ago)
- Topics: artificial-intelligence, gpt4, machine-learning, multi-modal, multi-modality
- Language: Python
- Homepage: https://discord.gg/GYbXvDGevY
- Size: 2.18 MB
- Stars: 20
- Watchers: 3
- Forks: 1
- Open Issues: 2
Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
# MLX Transformer
An implementation of high-performance transformer modules in MLX, Apple's machine learning framework.
# Install
`pip3 install --upgrade mlx-transformer`
# Usage
```python
import mlx.core as mx
from mlx_transformer.main import Transformer

model = Transformer(
    vocab_size=10000,
    depth=12,
    dim=512,
    heads=8,
)

# Lower (inclusive) and upper (exclusive) bounds for the random integers
low = 0
high = 10

# Generate a single random integer in the interval [low, high)
rand_int = mx.random.randint(low, high)
print(rand_int)  # a scalar array holding an integer between 0 and 9

# Generate random integer arrays in the interval [low, high)
shape = [1, 10000, 512]  # Shape of each input array
q = mx.random.randint(low, high, shape)
k = mx.random.randint(low, high, shape)
v = mx.random.randint(low, high, shape)

# Use the random arrays to perform a forward pass through the model
output = model(q, k, v)
print(output.shape)
```
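The usage example passes three arrays named `q`, `k`, and `v`, which suggests the model follows the usual query/key/value attention convention. This README does not show the internals of `Transformer`, so the sketch below is only the textbook scaled dot-product attention. It is written in plain Python so it runs without MLX. The function names `softmax` and `attention` are illustrative and are not part of `mlx_transformer`.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Textbook scaled dot-product attention on lists of row vectors.

    q: list of query rows, each of length d
    k: list of key rows, each of length d
    v: list of value rows, one per key
    Assumption: this mirrors the standard formulation, not necessarily
    what mlx_transformer.main.Transformer does internally.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        # Similarity of this query to every key, scaled by 1/sqrt(d)
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted average of the value rows
        out.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return out


# A zero query scores every key equally, so the output is the mean of the values.
print(attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
```

Each output row is a weighted average of the rows of `v`, with weights given by how closely the query matches each key.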
# License
MIT