https://github.com/birkhoffg/jax-dataloader
Pytorch-like dataloaders in JAX.
- Host: GitHub
- URL: https://github.com/birkhoffg/jax-dataloader
- Owner: BirkhoffG
- License: apache-2.0
- Created: 2023-01-12T03:13:44.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2024-04-02T21:48:37.000Z (9 months ago)
- Last Synced: 2024-04-24T04:11:38.788Z (9 months ago)
- Topics: dataloader, dataset, datasets, deep-learning, huggingface-datasets, jax, jax-dataloader, pytorch, tensorflow
- Language: Jupyter Notebook
- Homepage: https://birkhoffg.github.io/jax-dataloader/
- Size: 903 KB
- Stars: 37
- Watchers: 2
- Forks: 3
- Open Issues: 3
Metadata Files:
- Readme: README.md
- License: LICENSE
# Dataloader for JAX
![Python](https://img.shields.io/pypi/pyversions/jax-dataloader.svg)
![CI
status](https://github.com/BirkhoffG/jax-dataloader/actions/workflows/nbdev.yaml/badge.svg)
![Docs](https://github.com/BirkhoffG/jax-dataloader/actions/workflows/deploy.yaml/badge.svg)
![pypi](https://img.shields.io/pypi/v/jax-dataloader.svg) ![GitHub
License](https://img.shields.io/github/license/BirkhoffG/jax-dataloader.svg)

## Overview
`jax_dataloader` brings a *pytorch-like* dataloader API to `jax`. It
supports:

- **4 datasets to download and pre-process data**:
  - [jax dataset](https://birkhoffg.github.io/jax-dataloader/dataset/)
  - [huggingface datasets](https://github.com/huggingface/datasets)
  - [pytorch Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
  - [tensorflow dataset](https://www.tensorflow.org/datasets)
- **3 backends to iteratively load batches**:
  - [jax dataloader](https://birkhoffg.github.io/jax-dataloader/core.html#jax-dataloader)
  - [pytorch dataloader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
  - [tensorflow dataset](https://www.tensorflow.org/datasets)

A minimal `jax-dataloader` example:
``` python
import jax.numpy as jnp
import jax_dataloader as jdl

jdl.manual_seed(1234)  # Set the global seed to 1234 for reproducibility

# Any supported dataset works here; a small ArrayDataset is used for illustration
dataset = jdl.ArrayDataset(jnp.arange(100).reshape(10, 10), jnp.arange(10))

dataloader = jdl.DataLoader(
    dataset,  # Can be a jdl.Dataset or a pytorch, huggingface, or tensorflow dataset
    backend='jax',  # Use the 'jax' backend for loading data
    batch_size=32,  # Batch size
    shuffle=True,  # Reshuffle the data at every epoch or not
    drop_last=False,  # Drop the last incomplete batch or not
)
batch = next(iter(dataloader))  # Fetch the next batch
```

## Installation

The latest `jax-dataloader` release can be installed directly from PyPI:
``` sh
pip install jax-dataloader
```

or install directly from the repository:

``` sh
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
```

> [!NOTE]
>
> We keep `jax-dataloader`’s dependencies minimal: installing it only
> pulls in `jax` and `plum-dispatch` (for backend dispatching). If you
> wish to use the integrations with [`pytorch`](https://pytorch.org/),
> huggingface [`datasets`](https://github.com/huggingface/datasets), or
> [`tensorflow`](https://www.tensorflow.org/), we highly recommend
> installing those dependencies manually.
>
> You can also run `pip install "jax-dataloader[all]"` to install
> everything (not recommended).

## Usage
[`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core.html#dataloader)
follows an API similar to the pytorch dataloader.

- The `dataset` should be an instance of a subclass of
  `jax_dataloader.core.Dataset`, `torch.utils.data.Dataset`, (the
  huggingface) `datasets.Dataset`, or `tf.data.Dataset`.
- The `backend` should be one of `"jax"`, `"pytorch"`, or
  `"tensorflow"`. This argument specifies which backend dataloader is
  used to load batches.

Note that not every dataset is compatible with every backend. See the
compatibility table below:

| Backend        | `jdl.Dataset` | `torch_data.Dataset` | `tf.data.Dataset` | `datasets.Dataset` |
|:---------------|:--------------|:---------------------|:------------------|:-------------------|
| `"jax"`        | ✅            | ❌                   | ❌                | ✅                 |
| `"pytorch"`    | ✅            | ✅                   | ❌                | ✅                 |
| `"tensorflow"` | ✅            | ❌                   | ✅                | ✅                 |

### Using [`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)
The `jax_dataloader.core.ArrayDataset` is an easy way to wrap multiple
`jax.numpy` arrays into one Dataset. For example, we can create an
[`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)
as follows:

``` python
import jax.numpy as jnp
import jax_dataloader as jdl

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)
```

This `arr_ds` can be loaded by *every* backend.
``` python
# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
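
# Hypothetical arithmetic check (plain Python, not a jax_dataloader API):
# `arr_ds` holds 10 samples, so batch_size=5 yields 2 batches per epoch.
# With drop_last=False a trailing partial batch is kept (ceiling division);
# with drop_last=True it is discarded (floor division), e.g. for batch_size=3:
n_samples = 10
n_batches_keep = -(-n_samples // 3)  # ceil(10 / 3) = 4, last batch has 1 sample
n_batches_drop = n_samples // 3      # floor(10 / 3) = 3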
```

### Using Huggingface Datasets
The huggingface [datasets](https://github.com/huggingface/datasets) is a
modern library for downloading, pre-processing, and sharing datasets.
`jax_dataloader` supports passing huggingface datasets directly.

``` python
from datasets import load_dataset
```

For example, we load the `"squad"` dataset from `datasets`:

``` python
hf_ds = load_dataset("squad")
```

Then, we can use `jax_dataloader` to load batches of `hf_ds`.
``` python
# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)
```

### Using Pytorch Datasets
The [pytorch Dataset](https://pytorch.org/docs/stable/data.html) and its
ecosystem (e.g.,
[torchvision](https://pytorch.org/vision/stable/index.html),
[torchtext](https://pytorch.org/text/stable/index.html),
[torchaudio](https://pytorch.org/audio/stable/index.html)) provide many
built-in datasets. `jax_dataloader` supports passing a pytorch Dataset
directly.

> [!NOTE]
>
> Unfortunately, a [pytorch
> Dataset](https://pytorch.org/docs/stable/data.html) can only work with
> `backend="pytorch"`. See the example below.

``` python
from torchvision.datasets import MNIST
import numpy as np
```

We load the MNIST dataset from `torchvision`. The `transform` function
converts each image to a `numpy.array` of floats.

``` python
pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)
```

This `pt_ds` can **only** be loaded via the `"pytorch"` dataloader.
``` python
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
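
# Hypothetical lookup (not a jax_dataloader API): the compatibility table from
# the Usage section encoded as a dict, showing why `pt_ds` (a
# torch.utils.data.Dataset) is limited to the "pytorch" backend.
COMPATIBLE_BACKENDS = {
    "jdl.Dataset": {"jax", "pytorch", "tensorflow"},
    "torch_data.Dataset": {"pytorch"},
    "tf.data.Dataset": {"tensorflow"},
    "datasets.Dataset": {"jax", "pytorch", "tensorflow"},
}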
```

### Using Tensorflow Datasets

`jax_dataloader` supports passing [tensorflow
datasets](https://www.tensorflow.org/datasets) directly.

``` python
import tensorflow_datasets as tfds
import tensorflow as tf
```

For instance, we can load the MNIST dataset from `tensorflow_datasets`:

``` python
tf_ds = tfds.load('mnist', split='test', as_supervised=True)
```

and use `jax_dataloader` to iterate over the dataset.
``` python
dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)
```
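
To make the `batch_size`, `shuffle`, and `drop_last` arguments used in the examples above concrete, here is a minimal NumPy sketch of the batching logic a dataloader performs. It is an illustration under simplifying assumptions, not `jax_dataloader`'s implementation: the `iterate_batches` helper is hypothetical, and it draws a single permutation from a fixed `seed`, whereas a real loader would reshuffle at every epoch.

``` python
import numpy as np


def iterate_batches(arrays, batch_size, shuffle=False, drop_last=False, seed=0):
    """Yield tuples of mini-batches from arrays that share their first dimension.

    Hypothetical helper illustrating `batch_size`, `shuffle`, and `drop_last`;
    it is not part of jax_dataloader.
    """
    n = len(arrays[0])
    assert all(len(a) == n for a in arrays), "arrays must share the first dimension"
    indices = np.arange(n)
    if shuffle:
        # One permutation of sample indices, reproducible via `seed`
        indices = np.random.default_rng(seed).permutation(n)
    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]
        if drop_last and len(batch_idx) < batch_size:
            break  # Skip the trailing partial batch
        # Index every array with the same indices so features and labels stay aligned
        yield tuple(a[batch_idx] for a in arrays)


X = np.arange(100).reshape(10, 10)
y = np.arange(10)
batches = list(iterate_batches((X, y), batch_size=3, shuffle=True, seed=1234))
```

Here 10 samples with `batch_size=3` give four batches, the last holding a single sample; passing `drop_last=True` discards that partial batch and leaves three.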