https://github.com/jaketae/graph-neural-ode
Graph neural ordinary differential equations
gnn graph-neural-networks neural-ode ordinary-differential-equations pytorch
- Host: GitHub
- URL: https://github.com/jaketae/graph-neural-ode
- Owner: jaketae
- License: other
- Created: 2023-12-08T04:31:50.000Z (almost 2 years ago)
- Default Branch: master
- Last Pushed: 2023-12-10T08:00:56.000Z (almost 2 years ago)
- Last Synced: 2024-11-30T19:34:02.854Z (10 months ago)
- Topics: gnn, graph-neural-networks, neural-ode, ordinary-differential-equations, pytorch
- Language: Python
- Homepage:
- Size: 1.01 MB
- Stars: 5
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
## Graph Neural Differential Equations
This repository contains code for [Graph Neural ODE++](https://drive.google.com/file/d/1Rjn1X87AvP62rmBsImnTcj74V1Mc6IS_/view?usp=sharing). This work was completed as part of CPSC 483: Deep Learning on Graph-Structured Data.
## Abstract
> We propose Graph Neural ODE++, an improved paradigm for Graph Neural Ordinary Differential Equations (GDEs). Inspired by recent literature in score-based generative models, we explore two different heuristics for training GDEs: linear simplex refinement and consistency modeling. We observe that both methods improve model performance on standard transductive node classification datasets, albeit marginally. Furthermore, we show that there is a direct relationship between training methodology and the behavior of the model at different time steps within the integration window of the ODE.
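
For readers new to GDEs, the idea of an "integration window" can be sketched in a few lines. The snippet below is only an illustration under stated assumptions, not the repository's model: it treats one graph-convolution layer as the right-hand side of dH/dt = tanh(Â H W) and integrates it with fixed Euler steps. The toy graph, the weights `w`, and all function names are made up for the example.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def normalized_adjacency(edges, n):
    """Row-normalized adjacency matrix with self-loops for an undirected graph."""
    a = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for i, j in edges:
        a[i][j] = a[j][i] = 1.0
    return [[v / sum(row) for v in row] for row in a]

def vector_field(a_hat, h, w):
    """One graph-convolution layer used as the ODE right-hand side."""
    return [[math.tanh(v) for v in row] for row in matmul(matmul(a_hat, h), w)]

def euler_integrate(a_hat, h0, w, t1=1.0, steps=10):
    """Integrate H from t=0 to t=t1 and return the state after every step."""
    dt = t1 / steps
    trajectory = [h0]
    h = h0
    for _ in range(steps):
        dh = vector_field(a_hat, h, w)
        h = [[x + dt * d for x, d in zip(h_row, d_row)] for h_row, d_row in zip(h, dh)]
        trajectory.append(h)
    return trajectory

# Toy inputs (assumptions): a 3-node path graph with 2 features per node
a_hat = normalized_adjacency([(0, 1), (1, 2)], n=3)
h0 = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
w = [[0.5, -0.5], [0.25, 0.75]]

trajectory = euler_integrate(a_hat, h0, w, t1=1.0, steps=10)
print(len(trajectory))  # 11 states: the initial one plus one per step
print(trajectory[-1])   # node features at the end of the integration window
```

Each entry of `trajectory` is the node-feature matrix at one time step, which is the kind of intermediate state the abstract refers to when it discusses model behavior inside the integration window.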
## Quickstart
1. Clone this repository.
```
$ git clone https://github.com/jaketae/graph-neural-ode/
```
2. Create a Python virtual environment and install package requirements.
```
$ cd graph-neural-ode
$ python -m venv venv
$ source venv/bin/activate # activate the environment so pip installs into it
$ pip install -U pip wheel # update pip
$ pip install -r requirements.txt
```
3. Train all three models (GDE, GDE++ LSR, GDE++ CM) on the Cora dataset.
```
$ CUDA_VISIBLE_DEVICES=0 DATASET=cora sh train.sh
```
## Training
To train a model, run [`main.py`](main.py). The full list of supported arguments is shown below.
```
$ python main.py --help
usage: main.py [-h] [--name NAME] [--dataset [{cora,citeseer,pubmed}]] [--repeat REPEAT] [--hidden_channels HIDDEN_CHANNELS]
[--steps STEPS] [--dropout DROPOUT] [--atol ATOL] [--rtol RTOL] [--verbose VERBOSE] [--guide | --no-guide]
[--fast | --no-fast] [--adjoint | --no-adjoint] [--seed SEED]
[--solver [{dopri8,dopri5,bosh3,fehlberg2,adaptive_heun,euler,midpoint,heun3,rk4,explicit_adams,implicit_adams,fixed_adams,scipy_solver}]]

options:
-h, --help show this help message and exit
--name NAME
--dataset [{cora,citeseer,pubmed}]
--repeat REPEAT
--hidden_channels HIDDEN_CHANNELS
--steps STEPS
--dropout DROPOUT
--atol ATOL
--rtol RTOL
--verbose VERBOSE
--guide, --no-guide
--fast, --no-fast
--adjoint, --no-adjoint
--seed SEED
--solver [{dopri8,dopri5,bosh3,fehlberg2,adaptive_heun,euler,midpoint,heun3,rk4,explicit_adams,implicit_adams,fixed_adams,scipy_solver}]
```
The script will report the mean and standard deviation of the test accuracy under `output/results`. The best model checkpoint, as measured by validation accuracy, will be saved under `output/checkpoints`.
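As a rough illustration of the reported statistics, the sketch below aggregates per-run test accuracies into a mean and standard deviation. It assumes the statistics are taken over the `--repeat` runs and uses the sample standard deviation; the repository may compute them differently. The `summarize` helper and the accuracy values are hypothetical placeholders, not results from this project.

```python
from statistics import mean, stdev

def summarize(accuracies):
    """Return (mean, sample standard deviation) of per-run test accuracies."""
    if len(accuracies) < 2:
        # stdev needs at least two values, so report 0.0 for a single run
        return mean(accuracies), 0.0
    return mean(accuracies), stdev(accuracies)

# Placeholder accuracies for five repeated runs (made up for illustration)
runs = [0.80, 0.82, 0.81, 0.79, 0.83]
avg, std = summarize(runs)
print(f"test accuracy: {avg:.4f} ± {std:.4f}")
```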
## Inference
To evaluate a model checkpoint, run [`inference.py`](inference.py). The full list of supported arguments is shown below.
```
$ python inference.py --help
usage: inference.py [-h] [--name NAME] [--dataset [{cora,citeseer,pubmed}]] [--repeat REPEAT] [--hidden_channels HIDDEN_CHANNELS]
[--steps STEPS] [--dropout DROPOUT] [--atol ATOL] [--rtol RTOL] [--verbose VERBOSE] [--guide | --no-guide]
[--fast | --no-fast] [--adjoint | --no-adjoint] [--seed SEED]
[--solver [{dopri8,dopri5,bosh3,fehlberg2,adaptive_heun,euler,midpoint,heun3,rk4,explicit_adams,implicit_adams,fixed_adams,scipy_solver}]]

options:
-h, --help show this help message and exit
--name NAME
--dataset [{cora,citeseer,pubmed}]
--repeat REPEAT
--hidden_channels HIDDEN_CHANNELS
--steps STEPS
--dropout DROPOUT
--atol ATOL
--rtol RTOL
--verbose VERBOSE
--guide, --no-guide
--fast, --no-fast
--adjoint, --no-adjoint
--seed SEED
--solver [{dopri8,dopri5,bosh3,fehlberg2,adaptive_heun,euler,midpoint,heun3,rk4,explicit_adams,implicit_adams,fixed_adams,scipy_solver}]
```
The script will automatically locate the checkpoint file based on the `name` and `dataset` arguments.
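The paired flags in the usage text, such as `--adjoint | --no-adjoint`, are the form that Python's `argparse.BooleanOptionalAction` (Python 3.9+) prints, and `[{cora,citeseer,pubmed}]` is what an optional-value argument with `choices` looks like. The sketch below rebuilds a small subset of the interface to show how such flags parse. It is not the repository's parser: the defaults and the example `--name` value are assumptions.

```python
import argparse

def build_parser():
    """Build a parser mirroring a subset of the flags in the usage text."""
    parser = argparse.ArgumentParser(prog="inference.py")
    parser.add_argument("--name", type=str, default="gde")  # default is an assumption
    parser.add_argument(
        "--dataset",
        nargs="?",  # optional value, shown as [{...}] in the usage text
        choices=["cora", "citeseer", "pubmed"],
        default="cora",  # default is an assumption
    )
    parser.add_argument("--repeat", type=int, default=1)  # default is an assumption
    # BooleanOptionalAction registers both --flag and --no-flag
    parser.add_argument("--adjoint", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--guide", action=argparse.BooleanOptionalAction, default=False)
    return parser

# Parse an example command line ("demo" is a hypothetical run name)
args = build_parser().parse_args(
    ["--name", "demo", "--dataset", "citeseer", "--no-adjoint", "--guide"]
)
print(args.name, args.dataset, args.adjoint, args.guide)  # demo citeseer False True
```

Since the checkpoint is located from `name` and `dataset`, inference presumably needs the same two values that were used for training.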
## License
Released under the [MIT License](LICENSE).