https://github.com/zombie-einstein/jaxpr-viz
Jaxpr Visualisation Tool
computation-graph graph graphviz jax visualization
- Host: GitHub
- URL: https://github.com/zombie-einstein/jaxpr-viz
- Owner: zombie-einstein
- License: mit
- Created: 2023-08-19T17:36:16.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-12-22T14:57:01.000Z (about 1 year ago)
- Last Synced: 2025-08-29T13:43:53.192Z (5 months ago)
- Topics: computation-graph, graph, graphviz, jax, visualization
- Language: Python
- Homepage:
- Size: 423 KB
- Stars: 28
- Watchers: 2
- Forks: 1
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Jaxpr-Viz
JAX Computation Graph Visualisation Tool
JAX has built-in functionality to visualise the
HLO graph generated by JAX, but I've found this rather
low-level for some use-cases.
The intention of this package is to visualise how
sub-functions are connected in JAX programs. It does
this by converting the [jaxpr](https://jax.readthedocs.io/en/latest/jaxpr.html)
representation into a pydot graph. See [here](.github/docs/gallery.md)
for examples.
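To see the representation jpviz starts from, JAX's own `jax.make_jaxpr` prints the jaxpr of a function, with each nested jit call appearing as an equation that holds its own inner jaxpr. The sketch below also walks that jaxpr and emits Graphviz DOT text, drawing each nested jit call as a labelled cluster. It is a hypothetical illustration of the jaxpr-to-graph idea, not jpviz's actual code: `jaxpr_to_dot` is an invented name, it assumes jit equations carry `name` and `jaxpr` entries in their params (as the jit primitive does in current JAX releases), and it deliberately omits data-flow edges.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    x = foo(x)
    return x - 1


def jaxpr_to_dot(closed_jaxpr) -> str:
    """Hypothetical sketch: emit DOT text with one cluster per nested jit call.

    Only nodes and nesting are drawn; data-flow edges are omitted for brevity.
    """
    lines = ["digraph jaxpr {"]
    counter = 0

    def emit(jaxpr, indent):
        nonlocal counter
        for eqn in jaxpr.eqns:
            counter += 1
            inner = eqn.params.get("jaxpr")
            if inner is not None and "name" in eqn.params:
                # Assumed jit call: draw it as a labelled sub-graph (cluster).
                lines.append(f"{indent}subgraph cluster_{counter} {{")
                lines.append(f'{indent}  label="{eqn.params["name"]}";')
                # ClosedJaxpr wraps a Jaxpr; fall back to the object itself otherwise.
                emit(getattr(inner, "jaxpr", inner), indent + "  ")
                lines.append(f"{indent}}}")
            else:
                # A primitive operation such as mul or sub: draw a single node.
                lines.append(f'{indent}n{counter} [label="{eqn.primitive.name}"];')

    emit(closed_jaxpr.jaxpr, "  ")
    lines.append("}")
    return "\n".join(lines)


closed = jax.make_jaxpr(bar)(jnp.arange(10))
print(closed)  # the jaxpr, with nested jit equations for bar and foo
print(jaxpr_to_dot(closed))
```

The DOT text could then be handed to Graphviz or pydot for rendering; a real tool would also track jaxpr variables to draw the edges between nodes.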
> **NOTE:** This project is still at an early stage and may not
> support all JAX functionality (or combinations thereof). If you spot
> strange behaviour, please create a [GitHub issue](https://github.com/zombie-einstein/jaxpr-viz/issues).
## Installation
Install with pip:
```bash
pip install jpviz
```
Depending on your system, you may also need to install [Graphviz](https://www.graphviz.org/), which pydot uses to render images.
## Usage
Jaxpr-viz can be used to visualise jit-compiled (and nested)
functions. `jpviz.draw` wraps a jit-compiled function; calling the
wrapped function with concrete values returns a [pydot](https://github.com/pydot/pydot)
graph.
For example, this simple computation graph
```python
import jax
import jax.numpy as jnp
import jpviz

@jax.jit
def foo(x):
    return 2 * x

@jax.jit
def bar(x):
    x = foo(x)
    return x - 1

# Wrap the function and call it with concrete arguments;
# here dot_graph is a pydot object
dot_graph = jpviz.draw(bar)(jnp.arange(10))
# Render the graph to a PNG file
dot_graph.write_png("computation_graph.png")
```
produces an image of the computation graph in which `foo` is collapsed into a single node.

Pydot supports a number of output formats and rendering options; see
[here](https://github.com/pydot/pydot#output).
> **NOTE:** For sub-functions to show as nodes/sub-graphs they
> need to be marked with `@jax.jit`; otherwise they will be
> merged into their parent graph.
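The effect described in this note can be checked directly on the jaxpr. In the sketch below, the hypothetical helper `jit_names` collects the `name` of every nested jit call, assuming jit equations carry `name` and `jaxpr` params as they do in current JAX releases. The un-jitted helper leaves no call of its own (its multiply is inlined into the parent), while the jitted helper shows up as a separate nested call.

```python
import jax
import jax.numpy as jnp


def jit_names(closed_jaxpr):
    """Hypothetical helper: names of all nested jit calls, outermost first."""
    names = []

    def walk(jaxpr):
        for eqn in jaxpr.eqns:
            inner = eqn.params.get("jaxpr")
            if inner is not None and "name" in eqn.params:
                names.append(eqn.params["name"])
                # ClosedJaxpr wraps a Jaxpr; fall back to the object itself otherwise.
                walk(getattr(inner, "jaxpr", inner))

    walk(closed_jaxpr.jaxpr)
    return names


def double_plain(x):  # not jitted: its ops are inlined into the caller
    return 2 * x


@jax.jit
def double_jit(x):  # jitted: appears as its own nested call
    return 2 * x


@jax.jit
def parent_plain(x):
    return double_plain(x) - 1


@jax.jit
def parent_jit(x):
    return double_jit(x) - 1


x = jnp.arange(10)
print(jit_names(jax.make_jaxpr(parent_plain)(x)))  # ['parent_plain']
print(jit_names(jax.make_jaxpr(parent_jit)(x)))    # ['parent_jit', 'double_jit']
```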
### Jupyter Notebook
To show the rendered graph in a Jupyter notebook, you can use the
helper function `view_pydot`:
```python
...
dot_graph = jpviz.draw(bar)(jnp.arange(10))
jpviz.view_pydot(dot_graph)
```
### Visualisation Options
#### Collapse Nodes
By default, functions composed only of primitive operations
are collapsed into a single node (like `foo` in the example above).
The full computation graph can be rendered by setting the `collapse_primitives`
flag to `False`; doing so for the example above
```python
...
dot_graph = jpviz.draw(bar, collapse_primitives=False)(jnp.arange(10))
...
```
produces a graph in which `foo` is expanded into its individual primitive operations.

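The collapsing rule can be approximated as a predicate on a function's jaxpr: a function counts as "primitive-only" if none of its equations holds a nested jaxpr. The sketch below implements that reading under the assumption that nested jaxprs can be found by duck typing on equation params; `nested_jaxprs`, `is_primitive_only` and `body_of` are hypothetical names, and jpviz's own criterion may differ in detail.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


def nested_jaxprs(eqn):
    """Yield every jaxpr nested in an equation's params (jit, cond, scan, ...)."""
    for value in eqn.params.values():
        items = value if isinstance(value, (tuple, list)) else (value,)
        for item in items:
            if hasattr(item, "jaxpr") and hasattr(item, "consts"):
                yield item.jaxpr  # a ClosedJaxpr: unwrap to its Jaxpr
            elif hasattr(item, "eqns") and hasattr(item, "invars"):
                yield item  # already a bare Jaxpr


def is_primitive_only(jaxpr) -> bool:
    """Hypothetical collapse rule: True if no equation contains a nested jaxpr."""
    return not any(True for eqn in jaxpr.eqns for _ in nested_jaxprs(eqn))


def body_of(jitted_fn, *args):
    """Return the jaxpr inside the single top-level jit equation of `jitted_fn`."""
    (eqn,) = jax.make_jaxpr(jitted_fn)(*args).jaxpr.eqns
    inner = eqn.params["jaxpr"]
    return getattr(inner, "jaxpr", inner)


x = jnp.arange(10)
print(is_primitive_only(body_of(foo, x)))  # True: foo only multiplies
print(is_primitive_only(body_of(bar, x)))  # False: bar calls the jitted foo
```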
#### Show Types
By default, type information is included in the node labels. Setting
the `show_avals` flag to `False`
```python
...
dot_graph = jpviz.draw(bar, show_avals=False)(jnp.arange(10))
...
```
produces the same graph without type information in the node labels.

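The type information comes from the abstract value ("aval") that JAX attaches to each jaxpr variable, which records its shape and dtype. The sketch below builds a node label with and without it; `node_label` is a hypothetical helper that only mirrors the idea behind `show_avals`, and the label format jpviz actually uses may differ.

```python
import jax
import jax.numpy as jnp


def foo(x):
    return 2 * x


def node_label(name, aval, show_avals=True):
    """Hypothetical label builder: append dtype and shape when show_avals is True."""
    if not show_avals:
        return name
    dims = ",".join(str(d) for d in aval.shape)
    return f"{name}: {aval.dtype}[{dims}]"


closed = jax.make_jaxpr(foo)(jnp.arange(10))
for i, var in enumerate(closed.jaxpr.invars):
    print(node_label(f"arg {i}", var.aval))                    # e.g. arg 0: int32[10]
    print(node_label(f"arg {i}", var.aval, show_avals=False))  # arg 0
```

The dtype shown depends on JAX's configuration (for instance whether 64-bit mode is enabled), which is one reason labels can get long and why hiding them can be useful.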
> **NOTE:** Node labels don't currently correspond to
> argument/variable names in the original Python code. Because
> JAX flattens arguments and outputs into tuples, the labels instead
> correspond to the positions of the arguments and outputs.
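The positional labelling follows from how JAX flattens pytrees: `jax.make_jaxpr` turns the arguments into a flat sequence of leaves (dict entries in sorted key order), and each leaf becomes one positional input variable of the jaxpr. The sketch below uses only standard `jax.tree_util` calls to show that mapping for a hypothetical `affine` function; it is an illustration, not jpviz code.

```python
import jax
import jax.numpy as jnp


def affine(params, x):
    # params is a dict pytree; the function returns a tuple of two arrays
    return params["w"] * x + params["b"], x


params = {"w": jnp.float32(2.0), "b": jnp.float32(1.0)}
x = jnp.ones(3)

closed = jax.make_jaxpr(affine)(params, x)
paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path((params, x))

# Each flattened leaf becomes one positional input of the jaxpr, so a graph
# built from the jaxpr alone can only identify inputs by position.
for i, ((path, leaf), var) in enumerate(zip(paths_and_leaves, closed.jaxpr.invars)):
    print(f"input {i} <- {jax.tree_util.keystr(path)} shape={var.aval.shape}")
print(len(closed.jaxpr.outvars))  # 2: the output tuple is flattened too
```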
## Examples
See [here](.github/docs/gallery.md) for more examples of rendered computation graphs.
## Developers
Developer notes can be found [here](.github/docs/developers.md).