https://github.com/google/grain
Library for reading and processing ML training data.
data-pr jax machine-learning python
- Host: GitHub
- URL: https://github.com/google/grain
- Owner: google
- License: apache-2.0
- Created: 2022-08-05T08:29:51.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2026-01-12T16:20:53.000Z (13 days ago)
- Last Synced: 2026-01-12T21:57:14.050Z (12 days ago)
- Topics: data-pr, jax, machine-learning, python
- Language: Python
- Homepage: https://google-grain.readthedocs.io
- Size: 5.38 MB
- Stars: 648
- Watchers: 12
- Forks: 61
- Open Issues: 131
Metadata Files:
- Readme: README.md
- Changelog: CHANGELOG.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
README
# Grain - Feeding JAX Models
[Tests](https://github.com/google/grain/actions/workflows/tests.yml)
| [PyPI package](https://pypi.org/project/grain/)
[**Installation**](#installation)
| [**Quickstart**](#quickstart)
| [**Reference docs**](https://google-grain.readthedocs.io/en/latest/)
| [**Change logs**](https://google-grain.readthedocs.io/en/latest/changelog.html)
Grain is a Python library for reading and processing data for training and
evaluating JAX models. It is flexible, fast and deterministic.
Grain lets you define data processing steps in a simple, declarative way:
```python
import grain

dataset = (
    grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .shuffle(seed=42)  # Shuffles elements globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)  # Batches consecutive elements.
)

for batch in dataset:
    # Training step.
    pass
```
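To make the three steps concrete, here is a minimal pure-Python sketch of what shuffle, map and batch do conceptually. It is not Grain's implementation: the helper `shuffle_map_batch` is a hypothetical name, it uses the standard-library `random` module rather than Grain's own shuffling, and it returns plain lists, so the element order and batch types will differ from what Grain produces.

```python
import random


def shuffle_map_batch(items, seed, fn, batch_size):
    """Hypothetical stand-in for the shuffle -> map -> batch steps."""
    shuffled = list(items)
    # A seeded generator makes the shuffled order reproducible across runs.
    random.Random(seed).shuffle(shuffled)
    # Apply the per-element transformation.
    mapped = [fn(x) for x in shuffled]
    # Group consecutive elements; the final batch may be smaller.
    return [mapped[i:i + batch_size] for i in range(0, len(mapped), batch_size)]


batches = shuffle_map_batch(range(11), seed=42, fn=lambda x: x + 1, batch_size=2)
for batch in batches:
    print(batch)
```

With 11 input elements and a batch size of 2, this sketch yields six batches, the last holding a single element. Calling it again with the same seed returns the same batches, which is the kind of reproducibility the word "deterministic" refers to.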
Grain is designed to work with JAX models but it does not require JAX to run
and can be used with other frameworks as well.
## Installation
Grain is available on [PyPI](https://pypi.org/project/grain/) and can be
installed with `pip install grain`.
### Supported platforms
Grain does not directly use GPUs or TPUs in its transformations; processing
within Grain runs on the CPU by default.
| | Linux | Mac | Windows |
|---------|---------|---------|---------|
| x86_64 | yes | no | yes |
| aarch64 | yes | yes | n/a |
## Quickstart
- [Basic `Dataset` tutorial](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_basic_tutorial.html)
## Citing Grain
To cite this repository:
```
@software{grain2023github,
  author = {Marvin Ritter and Ihor Indyk and Aayush Singh and Andrew Audibert and Anoosha Seelam and Camelia Hanes and Eric Lau and Jacek Olesiak and Jiyang Kang and Xihui Wu},
  title = {{Grain} - Feeding JAX Models},
  url = {http://github.com/google/grain},
  version = {0.2.12},
  year = {2023},
}
```
The version number is intended to be that from [pyproject.toml](https://github.com/google/grain/blob/main/pyproject.toml), and the year corresponds to the project's open-source release.
## Existing users
Grain is used by [MaxText](https://github.com/google/maxtext/tree/main),
[Gemma](https://github.com/google-deepmind/gemma),
[kauldron](https://github.com/google-research/kauldron),
[maxdiffusion](https://github.com/AI-Hypercomputer/maxdiffusion) and multiple
internal Google projects.