## Get started with JAX! :computer: :zap:

The goal of this repo is to make it easier to get started with [JAX](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Haiku](https://github.com/deepmind/dm-haiku)!

The `JAX` ecosystem is becoming an increasingly popular alternative to `PyTorch` and `TensorFlow`. :sunglasses:

<br/>
<br/>

<p align="center">
<img src="readme_pics/jax_logo.png" width="300"/>
</p>

<br/>
<br/>

*Note: I'm only going to recommend content here that I've personally analyzed and found useful.
If you want a comprehensive list, check out the [awesome-jax repo](https://github.com/n2cholas/awesome-jax).*

## Table of Contents
  * [Machine Learning with JAX](#my-machine-learning-with-jax-tutorials)
    + [Tutorial #1: From Zero to Hero](#tutorial-1-from-zero-to-hero)
    + [Tutorial #2: From Hero to HeroPro+](#tutorial-2-from-hero-to-heropro)
    + [Tutorial #3: Coding a Neural Network from Scratch in Pure JAX](#tutorial-3-building-a-neural-network-from-scratch)
    + [Tutorial #4: Flax From Zero to Hero](#tutorial-4-machine-learning-with-flax---from-zero-to-hero)
    + [Tutorial #5: Haiku From Zero to Hero (coming soon)](#tutorial-5-coming-up-machine-learning-with-haiku---from-zero-to-hero)
  * [Other useful JAX resources](#other-useful-content)

## My Machine Learning with JAX Tutorials

*Tip on how to use the notebooks: just open a notebook directly in Google Colab
(you'll see a button at the top of the Jupyter file that takes you to Colab).
This way you can avoid having to set up a Python environment!
(This was especially convenient for me since I'm on Windows, which is still not supported.)*

### Tutorial #1: From Zero to Hero

In this video, we start from the basics and then gradually dig into the nitty-gritty details
of `jit`, `grad`, `vmap`, and various other idiosyncrasies of JAX.

[YouTube Video (Tutorial #1)](https://youtu.be/SstuvS-tVc0) <br/>
[Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_1_JAX_Zero2Hero_Colab.ipynb) <br/>

<p align="left">
<a href="https://www.youtube.com/watch?v=SstuvS-tVc0" target="_blank"><img src="https://img.youtube.com/vi/SstuvS-tVc0/0.jpg"
alt="JAX from zero to hero!" width="480" height="360" border="10" /></a>
</p>

### Tutorial #2: From Hero to HeroPro+

In this video, we learn all the additional components needed to train ML models (such as NNs) on multiple machines!
We'll train a simple MLP model, and we'll even train an ML model on 8 TPU cores!

[YouTube Video (Tutorial #2)](https://www.youtube.com/watch?v=CQQaifxuFcs) <br/>
[Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_2_JAX_HeroPro%2B_Colab.ipynb) <br/>

<p align="left">
<a href="https://www.youtube.com/watch?v=CQQaifxuFcs" target="_blank"><img src="https://img.youtube.com/vi/CQQaifxuFcs/0.jpg"
alt="JAX from Hero to HeroPro+!" width="480" height="360" border="10" /></a>
</p>

### Tutorial #3: Building a Neural Network from Scratch

Watch me code a Neural Network from scratch!
:partying_face: This is the 3rd video in my JAX tutorial series.

In this video, I build an [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) and train it as a classifier on MNIST
using PyTorch's data loader (it's trivial to swap in a more complex dataset), all in "pure" JAX (no Flax/Haiku/Optax).

I then do some additional analysis:
* Visualize the MLP's learned weights
* Visualize embeddings of a batch of images using t-SNE
* Check whether the network has too many dead ReLU neurons

[YouTube Video (Tutorial #3)](https://www.youtube.com/watch?v=6_PqUPxRmjY) <br/>
[Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb) (Note: I'll refactor it soon, but I'll keep a link to the original.)<br/>

<p align="left">
<a href="https://www.youtube.com/watch?v=6_PqUPxRmjY" target="_blank"><img src="https://img.youtube.com/vi/6_PqUPxRmjY/0.jpg"
alt="Building a Neural Network from Scratch in pure JAX!" width="480" height="360" border="10" /></a>
</p>

---

### Tutorial #4: Machine Learning with Flax - From Zero to Hero

In this video, I cover everything you need to know to get started with [Flax](https://github.com/google/flax)!

We cover `init`, `apply`, `TrainState`, etc.,
and other idiosyncrasies, such as the `mutable` and `rngs` keyword arguments.

[YouTube Video (Tutorial #4)](https://www.youtube.com/watch?v=5eUSmJvK8WA) <br/>
[Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_4_Flax_Zero2Hero_Colab.ipynb) <br/>

<p align="left">
<a href="https://www.youtube.com/watch?v=5eUSmJvK8WA" target="_blank"><img src="https://img.youtube.com/vi/5eUSmJvK8WA/0.jpg"
alt="Flax from Zero to Hero!" width="480" height="360" border="10" /></a>
</p>

---

### Tutorial #5 (coming up): Machine Learning with Haiku - From Zero to Hero

Coming soon.

## Other useful content

Aside from the [official docs](https://jax.readthedocs.io/), here are some resources that helped me.

### Videos

* [Introduction to JAX](https://www.youtube.com/watch?v=0mVmRHMaOJ4&ab_channel=GoogleCloudTech) (a very high-level overview)
* [JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas](https://www.youtube.com/watch?v=z-WSrQDXkuM&ab_channel=Enthought) (goes into much more detail)
* [NeurIPS 2020: JAX Ecosystem Meetup](https://www.youtube.com/watch?v=iDxJxIyzSiM&t=1s&ab_channel=DeepMind) (the DeepMind team on the ecosystem of libraries built around JAX)
* [Introduction to JAX for Machine Learning and More](https://www.youtube.com/watch?v=QkmKfzxbCLQ&ab_channel=UWaterlooDataScience) (a nice hands-on workshop)
* [Day 1 Talks: JAX, Flax & Transformers | HuggingFace](https://www.youtube.com/watch?v=fuAyUQcVzTY&ab_channel=HuggingFace) (all 4 talks are good)
* [Day 2 Talks: JAX, Flax & Transformers | HuggingFace](https://www.youtube.com/watch?v=__eG63ZP_5g&ab_channel=HuggingFace) (only the first 2 talks are relevant)

### Blogs

* [Using JAX to accelerate our research | DeepMind](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) (similar information to the NeurIPS
2020 video)
* [You don't know JAX | Colin Raffel](https://colinraffel.com/blog/you-don-t-know-jax.html)

## Acknowledgements

* The notebooks were heavily inspired by the official [JAX](https://jax.readthedocs.io/), [Flax](https://flax.readthedocs.io/en/latest/), and [Haiku](https://dm-haiku.readthedocs.io/en/latest/) docs.

## Citation

If you find this content useful, please cite the following:

```bibtex
@misc{Gordic2021GetStartedWithJAX,
  author = {Gordić, Aleksa},
  title = {Get started with JAX},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gordicaleksa/get-started-with-JAX}},
}
```

## Connect With Me

If you'd love to have some more AI-related content in your life :nerd_face:, consider:
* Subscribing to my YouTube channel [The AI Epiphany](https://www.youtube.com/c/TheAiEpiphany) :bell:
* Following me on [LinkedIn](https://www.linkedin.com/in/aleksagordic/) and [Twitter](https://twitter.com/gordic_aleksa) :bulb:
* Following me on [Medium](https://gordicaleksa.medium.com/) :books: :heart:
* Joining the [Discord](https://discord.gg/peBrCpheKE) community! :family:

## Licence

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/gordicaleksa/get-started-with-JAX/blob/master/LICENCE)