{"id":26075910,"url":"https://github.com/jecampagne/jaxtutos","last_synced_at":"2025-04-11T21:21:43.332Z","repository":{"id":42048231,"uuid":"480707234","full_name":"jecampagne/JaxTutos","owner":"jecampagne","description":"JAX Tutorial notebooks : basics, crash \u0026 tips, usage of optax/JaxOptim/Numpyro","archived":false,"fork":false,"pushed_at":"2025-02-25T13:54:21.000Z","size":37369,"stargazers_count":14,"open_issues_count":0,"forks_count":2,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-04-11T21:21:36.039Z","etag":null,"topics":["autodifferentiation","jax","jaxoptim","jit-compilation","numpy","numpyro","scipy","tutorial-demos"],"latest_commit_sha":null,"homepage":"","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"ecl-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/jecampagne.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-04-12T07:50:36.000Z","updated_at":"2025-02-25T13:54:25.000Z","dependencies_parsed_at":"2024-05-23T08:56:32.668Z","dependency_job_id":null,"html_url":"https://github.com/jecampagne/JaxTutos","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jecampagne%2FJaxTutos","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jecampagne%2FJaxTutos/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jecampagne%2FJaxTutos/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jecampagne%2FJaxTutos/manifests","owner_url"
:"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/jecampagne","download_url":"https://codeload.github.com/jecampagne/JaxTutos/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248480420,"owners_count":21110939,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["autodifferentiation","jax","jaxoptim","jit-compilation","numpy","numpyro","scipy","tutorial-demos"],"created_at":"2025-03-09T01:35:42.262Z","updated_at":"2025-04-11T21:21:43.297Z","avatar_url":"https://github.com/jecampagne.png","language":"Jupyter Notebook","funding_links":[],"categories":[],"sub_categories":[],"readme":"# JaxTutos\nThis repository provides some notebooks to learn JAX (basics and advanced) and use some libraries such as JaxOptim/Numpyro/...\n\nThis work was partially performed using resources from GENCI–IDRIS (Grant 2024-[AD010413957R1]).\n\n# Questions:  \n- Why JAX?: You need auto-differentiation first and want accelerated code ready for CPU/GPU/TPU devices, and you probably already know a bit of Python.\n- Will my code be scalable? Gemini (i.e. Google's alter ego of ChatGPT) is coded in JAX, so I guess you will find a good framework to get your use case working nicely.\n\n# Exchanges: \n- To discuss, you can use the `Discussions` menu.\n- To suggest new notebooks or code changes and/or to report bugs, use `Issues`.\n\n# Here is the list of Tutos in this repo:\n## A tour of some basics\n- [JAX_Cophy_tuto.ipynb](./JAX_Cophy_tuto.ipynb) : a Tuto on JAX basics (given at GDR Cophy 18–20 Nov. 
2024, IAP)\n- [JAX_get_started.ipynb](./JAX_get_started.ipynb) : get a flavour of the coding and an example of auto-diff\n- [JAX_AutoDiff_UserCode.ipynb](./JAX_AutoDiff_UserCode.ipynb) : more on the usage of auto-diff in a real use case, \"integration methods\"  \n- [JIT_fractals.ipynb](./JIT_fractals.ipynb) : **(GPU better)** by producing some fractal images, discover some jax.lax control-flow functions and nested vmap\n- [JAX_control_flow.ipynb](./JAX_control_flow.ipynb): jax.lax control flow (fori_loop/scan/while_loop, cond). Some \"crashes\" are analysed.\n- [JAX_exo_sum_image_patches.ipynb](./JAX_exo_sum_image_patches.ipynb): Exercise: sum patches of identical size from a 2D image. Experience the differences in compilation/execution times between methods on CPU and GPU (if possible).\n- [JAX-MultiGPus.ipynb](./JAX-MultiGPus.ipynb) : **(4 GPUs)** (e.g. on the Jean Zay JupyterHub platform) use the \"data sharding module\" to distribute arrays and perform parallelization (2D image production: a simple 2D function and a revisit of the Julia set from `JIT_fractals.ipynb`).\n## More advanced topics:\nDesigned for people with OO thinking (C++/Python), and/or with existing code in mind to transform into JAX. Based on real use cases I experienced. This is more advanced and technical, but with \"crashes\" analysed.\n- [JAX_JIT_in_class.ipynb](./JAX_JIT_in_class.ipynb): how to use JIT for class methods (as opposed to JIT for an isolated function). 
\n- [JAX_static_traced_var_func.ipynb](./JAX_static_traced_var_func.ipynb): why and when one needs to use pure NumPy functions to make JIT work\n- [JAX_PyTree_initialisation.ipynb](./JAX_PyTree_initialisation.ipynb): how to perform variable initialisation in a class\n## Using JAX \u0026 some third-party libraries for real jobs\n- [JAX_jaxopt_optax.ipynb](./JAX_jaxopt_optax.ipynb): some uses of the JaxOptim \u0026 Optax libraries\n- [JAX_MC_Sampling.ipynb](./JAX_MC_Sampling.ipynb): pedagogical notebook for Monte Carlo sampling using different techniques. Beyond the math, one experiences random number generation in JAX, which can be a subject by itself. I implement a simple HMC MCMC in both Python and JAX to see the difference.\n- [Numpyro_MC_Sampling.ipynb](./Numpyro_MC_Sampling.ipynb): here we give some simple examples using the Numpyro Probabilistic Programming Language\n- [JAX-GP-regression-piecewise.ipynb](./JAX-GP-regression-piecewise.ipynb): (**Not ready for Colab**) my implementation of a Gaussian Process library to see the differences with Sklearn and GPy.\n\n## Other TUTOs (absolutely not exhaustive...)\n- [JAX readthedocs Tutos](https://jax.readthedocs.io/en/latest/tutorials.html): at least up-to-date\n- [Kaggle TF_JAX Tutos (23 Dec. 2021)](https://www.kaggle.com/code/aakashnain/tf-jax-tutorials-part1): OK, but problem: 
JAX v0.2.26\n- [Keras 3 JAX Backend guide](https://keras.io/guides/): jax==0.4.20 \n\n# Other JAX libraries: \n- Have a look at [awesome-jax](https://project-awesome.org/n2cholas/awesome-jax)\n- More Cosmo-centred:\n   - [JaxPM](https://github.com/DifferentiableUniverseInitiative/JaxPM): JAX-powered Cosmological Particle-Mesh N-body Solver\n   - [S2FFT](http://www.jasonmcewen.org/project/s2fft/): JAX package for computing Fourier transforms on the sphere and rotation group\n   - [JAX-Cosmo](https://github.com/DifferentiableUniverseInitiative/jax_cosmo): a differentiable cosmology library in JAX\n   - [JAX-GalSim](https://github.com/GalSim-developers/JAX-GalSim): JAX version (paper in draft version) of the C++ GalSim code (GalSim is open-source software for simulating images of astronomical objects (stars, galaxies) in a variety of ways)\n   - [CosmoPower-JAX](https://github.com/dpiras/cosmopower-jax): example of a differentiable emulator of cosmological power spectra\n   - DISCO-DJ I: a differentiable Einstein-Boltzmann solver for cosmology ([here](https://arxiv.org/abs/2311.03291)): repo not yet public.\n- and many others concerning, for instance, Simulation Based Inference...\n\n\n## Installation (it depends on your local environment)\nMost of the notebooks run on Colab. 
(JAX 0.4.2x) \n\nIf you want a Conda environment `JaxTutos` (this is not guaranteed to work, because the GPU-based notebooks depend on your local \u0026 specific CUDA library):\n```\n# version specifiers are quoted so that the shell does not treat \u003e as a redirection\nconda create -n JaxTutos \"python\u003e=3.8\"\nconda activate JaxTutos\n# replace \u003cXYZ\u003e with the JAX version you need\npip install --upgrade \"jax[cuda]==\u003cXYZ\u003e\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\npip install \"numpyro\u003e=0.10.1\"\npip install \"jaxopt\u003e=0.6\"\npip install \"optax\u003e=0.1.4\"\npip install \"corner\u003e=2.2.1\"\npip install \"arviz\u003e=0.11.4\"\npip install matplotlib_inline\npip install \"seaborn\u003e=0.12.2\"\n```\n\nNote that starting with JAX v0.4.30 the installation procedure changes: see [CHANGELOG](https://github.com/google/jax/blob/main/CHANGELOG.md) \n\n# Some Docs\n- JAX: https://jax.readthedocs.io\n- NumPy : https://numpy.org/doc/stable/reference/index.html\n- Numpyro : https://num.pyro.ai/en/stable/getting_started.html#what-is-numpyro\n- JaxOpt: https://jaxopt.github.io/stable/\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjecampagne%2Fjaxtutos","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fjecampagne%2Fjaxtutos","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjecampagne%2Fjaxtutos/lists"}