https://github.com/nathom/jax-profiler
Memory profiler for JAX
- Host: GitHub
- URL: https://github.com/nathom/jax-profiler
- Owner: nathom
- License: MIT
- Created: 2024-08-11T18:36:01.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2024-08-12T02:21:55.000Z (almost 2 years ago)
- Last Synced: 2025-12-26T11:50:39.317Z (6 months ago)
- Language: Python
- Size: 5.86 KB
- Stars: 2
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# jax-profiler
Memory profiler for JAX
## Usage
Create profiler logs by calling `profiler.capture()` from your JAX code:
```python
from jaxprof import JaxProfiler

profiler = JaxProfiler()

def some_jax_code():
    ...  # your JAX computations
    profiler.capture()  # record profiler data at this point in the program
    ...
```
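The README does not say what `capture()` records. As a hedged sketch, assuming the profiler snapshots device memory through JAX's built-in `jax.profiler.device_memory_profile()` (which returns a pprof-format profile as bytes), a capture step might write numbered snapshot files into a log directory. `SnapshotWriter`, `log_dir`, `snapshot_fn`, and the `snapshot_NNNN.prof` file naming are hypothetical and not part of `jaxprof`.

```python
from pathlib import Path
from typing import Callable, Optional


def jax_memory_snapshot() -> bytes:
    """Return JAX's current device memory profile in pprof format."""
    import jax.profiler  # imported lazily so the writer can be tested without JAX

    return jax.profiler.device_memory_profile()


class SnapshotWriter:
    """Hypothetical sketch, not jaxprof's implementation: one file per capture()."""

    def __init__(
        self,
        log_dir: str = "jaxprof_logs",
        snapshot_fn: Optional[Callable[[], bytes]] = None,
    ) -> None:
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        # Default to JAX's device memory profile; a stub can be injected instead.
        self.snapshot_fn = snapshot_fn or jax_memory_snapshot
        self.count = 0

    def capture(self) -> Path:
        """Write the current snapshot to log_dir/snapshot_NNNN.prof and return its path."""
        path = self.log_dir / f"snapshot_{self.count:04d}.prof"
        path.write_bytes(self.snapshot_fn())
        self.count += 1
        return path
```

Injecting `snapshot_fn` keeps the writer testable without JAX installed. A snapshot written this way can be viewed with Google's `pprof` tool (for example `pprof --web snapshot_0000.prof`), which is how JAX's own memory-profiling guide inspects these profiles.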
Or run it in the background:
```python
profiler.capture_in_background()
```
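The README does not say how `capture_in_background()` schedules its captures. One common pattern, shown here as an assumption rather than the library's actual implementation, is a daemon thread that calls a capture function at a fixed interval until it is told to stop. `BackgroundSampler`, `interval_s`, `start()`, and `stop()` are hypothetical names.

```python
import threading
from typing import Callable


class BackgroundSampler:
    """Hypothetical sketch: call `capture` every `interval_s` seconds on a daemon thread."""

    def __init__(self, capture: Callable[[], object], interval_s: float = 0.5) -> None:
        self.capture = capture
        self.interval_s = interval_s
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        # Capture immediately, then once per interval. Event.wait returns True
        # as soon as stop() sets the event, which ends the loop.
        while True:
            self.capture()
            if self._stop_event.wait(self.interval_s):
                break

    def start(self) -> "BackgroundSampler":
        self._thread.start()
        return self

    def stop(self) -> None:
        """Signal the thread to finish and wait for the last capture to complete."""
        self._stop_event.set()
        self._thread.join()
```

For example, `BackgroundSampler(profiler.capture, interval_s=1.0).start()` would sample while the main thread runs JAX code. Because the thread is a daemon, a forgotten sampler does not keep the interpreter alive at exit; calling `stop()` explicitly still guarantees that no capture is mid-write when the logs are read.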
Generate plots from the profiler logs (run the script with `--help` to list its options):
```bash
python jaxprof.py --help
```