{"id":23324399,"url":"https://github.com/birkhoffg/jax-dataloader","last_synced_at":"2025-08-22T18:32:33.387Z","repository":{"id":115052094,"uuid":"587981258","full_name":"BirkhoffG/jax-dataloader","owner":"BirkhoffG","description":"Pytorch-like dataloaders for JAX.","archived":false,"fork":false,"pushed_at":"2025-05-26T23:29:37.000Z","size":1064,"stargazers_count":81,"open_issues_count":6,"forks_count":3,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-05-27T00:28:11.947Z","etag":null,"topics":["dataloader","dataset","datasets","deep-learning","huggingface-datasets","jax","jax-dataloader","pytorch","tensorflow"],"latest_commit_sha":null,"homepage":"https://birkhoffg.github.io/jax-dataloader/","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/BirkhoffG.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2023-01-12T03:13:44.000Z","updated_at":"2025-05-17T16:27:21.000Z","dependencies_parsed_at":"2023-11-28T06:31:23.660Z","dependency_job_id":"dc72dea5-feda-4ff1-8b86-f72646a821bf","html_url":"https://github.com/BirkhoffG/jax-dataloader","commit_stats":null,"previous_names":[],"tags_count":8,"template":false,"template_full_name":null,"purl":"pkg:github/BirkhoffG/jax-dataloader","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BirkhoffG%2Fjax-dataloader","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BirkhoffG%2Fjax-dataloader/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BirkhoffG%2Fjax-
dataloader/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BirkhoffG%2Fjax-dataloader/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/BirkhoffG","download_url":"https://codeload.github.com/BirkhoffG/jax-dataloader/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BirkhoffG%2Fjax-dataloader/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":271681603,"owners_count":24802078,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","status":"online","status_checked_at":"2025-08-22T02:00:08.480Z","response_time":65,"last_error":null,"robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":true,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["dataloader","dataset","datasets","deep-learning","huggingface-datasets","jax","jax-dataloader","pytorch","tensorflow"],"created_at":"2024-12-20T18:14:17.896Z","updated_at":"2025-08-22T18:32:33.369Z","avatar_url":"https://github.com/BirkhoffG.png","language":"Jupyter Notebook","readme":"# Dataloader for JAX\n\n\n\u003c!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! 
--\u003e\n\n![Python](https://img.shields.io/pypi/pyversions/jax-dataloader.svg)\n![CI\nstatus](https://github.com/BirkhoffG/jax-dataloader/actions/workflows/nbdev.yaml/badge.svg)\n![Docs](https://github.com/BirkhoffG/jax-dataloader/actions/workflows/deploy.yaml/badge.svg)\n![pypi](https://img.shields.io/pypi/v/jax-dataloader.svg) ![GitHub\nLicense](https://img.shields.io/github/license/BirkhoffG/jax-dataloader.svg)\n\u003ca href=\"https://static.pepy.tech/badge/jax-dataloader\"\u003e\u003cimg src=\"https://static.pepy.tech/badge/jax-dataloader\" alt=\"Downloads\"\u003e\u003c/a\u003e\n\n[**Overview**](#overview) \\| [**Installation**](#installation) \\|\n[**Usage**](#usage) \\|\n[**Documentation**](https://birkhoffg.github.io/jax-dataloader)\n\n## Overview\n\n`jax_dataloader` brings a *pytorch-like* dataloader API to `jax`. It\nsupports\n\n- **4 datasets to download and pre-process data**:\n\n  - [jax dataset](https://birkhoffg.github.io/jax-dataloader/dataset/)\n  - [huggingface datasets](https://github.com/huggingface/datasets)\n  - [pytorch\n    Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)\n  - [tensorflow dataset](https://www.tensorflow.org/datasets)\n\n- **3 backends to iteratively load batches**:\n\n  - [jax\n    dataloader](https://birkhoffg.github.io/jax-dataloader/core.html#jax-dataloader)\n  - [pytorch\n    dataloader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)\n  - [tensorflow dataset](https://www.tensorflow.org/datasets)\n\nA minimal `jax-dataloader` example:\n\n``` python\nimport jax_dataloader as jdl\n\njdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility\n\ndataloader = jdl.DataLoader(\n    dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset\n    backend='jax', # Use 'jax' backend for loading data\n    batch_size=32, # Batch size \n    shuffle=True, # Shuffle the dataloader every iteration or not\n    drop_last=False, # Drop the last batch or not\n    
generator=jdl.Generator() # Control the randomness of this dataloader \n)\n\nbatch = next(iter(dataloader)) # Get the first batch\n```\n\n## Installation\n\nThe latest `jax-dataloader` release can be installed directly from PyPI:\n\n``` sh\npip install jax-dataloader\n```\n\nor install directly from the repository:\n\n``` sh\npip install git+https://github.com/BirkhoffG/jax-dataloader.git\n```\n\n\u003e [!NOTE]\n\u003e\n\u003e We keep `jax-dataloader`’s dependencies minimal: a default install only pulls in\n\u003e `jax` and `plum-dispatch` (for backend dispatching).\n\u003e If you wish to use the integration with [`pytorch`](https://pytorch.org/),\n\u003e huggingface [`datasets`](https://github.com/huggingface/datasets), or\n\u003e [`tensorflow`](https://www.tensorflow.org/), we highly recommend\n\u003e installing those dependencies manually.\n\u003e\n\u003e You can also run `pip install jax-dataloader[all]` to install\n\u003e everything (not recommended).\n\n## Usage\n\n[`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core.html#dataloader)\nfollows a similar API to the pytorch dataloader.\n\n- The `dataset` should be an instance of (a subclass of)\n  `jax_dataloader.core.Dataset`, `torch.utils.data.Dataset`, (the\n  huggingface) `datasets.Dataset`, or `tf.data.Dataset`.\n- The `backend` should be one of `\"jax\"`, `\"pytorch\"`, or\n  `\"tensorflow\"`. This argument specifies which backend dataloader is used to\n  load batches.\n\nNote that not every dataset is compatible with every backend. 
See the\ncompatibility table below:\n\n|                | `jdl.Dataset` | `torch_data.Dataset` | `tf.data.Dataset` | `datasets.Dataset` |\n|:---------------|:--------------|:---------------------|:------------------|:-------------------|\n| `\"jax\"`        | ✅            | ❌                   | ❌                | ✅                 |\n| `\"pytorch\"`    | ✅            | ✅                   | ❌                | ✅                 |\n| `\"tensorflow\"` | ✅            | ❌                   | ✅                | ✅                 |\n\n### Using [`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)\n\nThe `jax_dataloader.core.ArrayDataset` is an easy way to wrap multiple\n`jax.numpy` arrays into one dataset. For example, we can create an\n[`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)\nas follows:\n\n``` python\nimport jax.numpy as jnp\n\n# Create features `X` and labels `y`\nX = jnp.arange(100).reshape(10, 10)\ny = jnp.arange(10)\n# Create an `ArrayDataset`\narr_ds = jdl.ArrayDataset(X, y)\n```\n\nThis `arr_ds` can be loaded by *every* backend.\n\n``` python\n# Create a `DataLoader` from the `ArrayDataset` via jax backend\ndataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)\n# Or we can use the pytorch backend\ndataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)\n# Or we can use the tensorflow backend\ndataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)\n```\n\n### Using Huggingface Datasets\n\nHuggingface [datasets](https://github.com/huggingface/datasets) is a\nmodern library for downloading, pre-processing, and sharing datasets.\n`jax_dataloader` supports directly passing huggingface datasets.\n\n``` python\nfrom datasets import load_dataset\n```\n\nFor example, we load the `\"squad\"` dataset from `datasets`:\n\n``` python\nhf_ds = load_dataset(\"squad\")\n```\n\nThen, we can use `jax_dataloader` to load batches of `hf_ds`.\n\n``` python\n# Create 
a `DataLoader` from the `datasets.Dataset` via jax backend\ndataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)\n# Or we can use the pytorch backend\ndataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)\n# Or we can use the tensorflow backend\ndataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)\n```\n\n### Using Pytorch Datasets\n\nThe [pytorch Dataset](https://pytorch.org/docs/stable/data.html) and its\necosystem (e.g.,\n[torchvision](https://pytorch.org/vision/stable/index.html),\n[torchtext](https://pytorch.org/text/stable/index.html),\n[torchaudio](https://pytorch.org/audio/stable/index.html)) support many\nbuilt-in datasets. `jax_dataloader` supports directly passing the\npytorch Dataset.\n\n\u003e [!NOTE]\n\u003e\n\u003e Unfortunately, the [pytorch\n\u003e Dataset](https://pytorch.org/docs/stable/data.html) can only work with\n\u003e `backend=\"pytorch\"`. See the example below.\n\n``` python\nfrom torchvision.datasets import MNIST\nimport numpy as np\n```\n\nWe load the MNIST dataset from `torchvision`. 
The `transform` lambda\nconverts each image to a `numpy.array`.\n\n``` python\npt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)\n```\n\nThis `pt_ds` can **only** be loaded via `\"pytorch\"` dataloaders.\n\n``` python\ndataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)\n```\n\n### Using Tensorflow Datasets\n\n`jax_dataloader` supports directly passing the [tensorflow\ndatasets](https://www.tensorflow.org/datasets).\n\n``` python\nimport tensorflow_datasets as tfds\nimport tensorflow as tf\n```\n\nFor instance, we can load the MNIST dataset from `tensorflow_datasets`\n\n``` python\ntf_ds = tfds.load('mnist', split='test', as_supervised=True)\n```\n\nand use `jax_dataloader` for iterating the dataset.\n\n``` python\ndataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)\n```\n","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fbirkhoffg%2Fjax-dataloader","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fbirkhoffg%2Fjax-dataloader","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fbirkhoffg%2Fjax-dataloader/lists"}