https://github.com/patrick-kidger/fasterneuraldiffeq
Code for "'Hey, that's not an ODE:' Faster ODE Adjoints via Seminorms" (ICML 2021)
- Host: GitHub
- URL: https://github.com/patrick-kidger/fasterneuraldiffeq
- Owner: patrick-kidger
- License: apache-2.0
- Created: 2020-08-13T08:19:06.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2022-10-19T00:21:18.000Z (over 2 years ago)
- Last Synced: 2025-04-15T08:59:33.064Z (15 days ago)
- Topics: controlled-differential-equations, deep-learning, deep-neural-networks, differential-equations, dynamical-systems, machine-learning, neural-differential-equations, numerical-analysis, numerical-methods, ordinary-differential-equations, pytorch
- Language: Python
- Homepage:
- Size: 635 KB
- Stars: 87
- Watchers: 6
- Forks: 9
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
"Hey, that's not an ODE": Faster ODE Adjoints via Seminorms
(ICML 2021)
[arXiv]
One simple-to-implement trick dramatically improves the speed at which [Neural ODEs](https://arxiv.org/abs/1806.07366) and [Neural CDEs](https://arxiv.org/abs/2005.08926) can be trained. (As much as doubling the speed.)
Backpropagation through a Neural ODE/CDE can be performed via the "adjoint method", which involves solving another differential equation backwards in time. However, it turns out that default numerical solvers are unnecessarily stringent when solving the adjoint equation, taking too many steps that are too small.
Tweaking things slightly reduces the number of function evaluations on the backward pass **by as much as 62%**. (Exact number will be problem-dependent, of course.)
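Concretely, the "tweak" is to measure solver error with a *seminorm*: one that ignores the (usually very large) parameter-adjoint part of the augmented backward state, so the step-size controller stops shrinking steps just to resolve error there. A minimal, self-contained illustration of the idea (not the repository's code, just a sketch of the principle):
```python
import torch

def rms(x):
    # Root-mean-square "size" of a tensor, as used by adaptive step-size controllers.
    return x.pow(2).mean().sqrt()

# Toy error estimate for the augmented backward state: zero error in the state y
# and its adjoint, nonzero error only in the parameter adjoints.
y_err = torch.zeros(10)
adj_y_err = torch.zeros(10)
adj_params_err = torch.ones(1000)

full = rms(torch.cat([y_err, adj_y_err, adj_params_err]))  # > 0: the solver would shrink its step
semi = max(rms(y_err), rms(adj_y_err))                     # == 0: parameter-adjoint error is ignored
print(full.item(), semi.item())
```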
_[torchdiffeq](https://github.com/rtqichen/torchdiffeq) now supports this feature natively!_
---
## Summary:
If you're using [torchdiffeq](https://github.com/rtqichen/torchdiffeq) (at least version 0.1.0) then replace
```python
import torchdiffeq

func = ...
y0 = ...
t = ...
torchdiffeq.odeint_adjoint(func=func, y0=y0, t=t)
```
with
```python
import torchdiffeq

def rms_norm(tensor):
    return tensor.pow(2).mean().sqrt()

def make_norm(state):
    state_size = state.numel()
    def norm(aug_state):
        # aug_state is the flattened augmented adjoint state; pick out the state y
        # and its adjoint adj_y, and ignore the rest (in particular the parameter adjoints).
        y = aug_state[1:1 + state_size]
        adj_y = aug_state[1 + state_size:1 + 2 * state_size]
        return max(rms_norm(y), rms_norm(adj_y))
    return norm

func = ...
y0 = ...
t = ...
torchdiffeq.odeint_adjoint(func=func, y0=y0, t=t,
adjoint_options=dict(norm=make_norm(y0)))
```
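For context, here is a hedged, self-contained sketch of the snippet above in use: a toy `torch.nn.Module` vector field plus a simple counter, so the number of backward-pass function evaluations can be compared with and without the custom norm. `ODEFunc`, the counter, and the toy problem are illustrative inventions; only `rms_norm`, `make_norm`, and the `adjoint_options` argument come from the snippet above.
```python
import torch
import torchdiffeq

class ODEFunc(torch.nn.Module):
    # Toy vector field dy/dt = f(t, y); torchdiffeq calls forward(t, y).
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 2)
        )
        self.nfe = 0  # number of function evaluations

    def forward(self, t, y):
        self.nfe += 1
        return self.net(y)

def rms_norm(tensor):
    return tensor.pow(2).mean().sqrt()

def make_norm(state):
    state_size = state.numel()
    def norm(aug_state):
        y = aug_state[1:1 + state_size]
        adj_y = aug_state[1 + state_size:1 + 2 * state_size]
        return max(rms_norm(y), rms_norm(adj_y))
    return norm

def backward_nfe(use_seminorm):
    func = ODEFunc()
    y0 = torch.randn(2)
    t = torch.linspace(0., 1., 10)
    kwargs = dict(adjoint_options=dict(norm=make_norm(y0))) if use_seminorm else {}
    y = torchdiffeq.odeint_adjoint(func=func, y0=y0, t=t, **kwargs)
    func.nfe = 0          # reset so that only the backward pass is counted
    y.sum().backward()
    return func.nfe

print("backward NFE, default norm:", backward_nfe(False))
print("backward NFE, seminorm:    ", backward_nfe(True))
```
On most problems the second number should be noticeably smaller; the exact saving is problem-dependent, as noted above.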
That's it.

## Reproducing experiments
The code for the Neural CDE and Symplectic ODE-Net experiments is available.

### Requirements
PyTorch >= 1.6
[torchdiffeq](https://github.com/rtqichen/torchdiffeq) >= 0.1.0
[torchcde](https://github.com/patrick-kidger/torchcde) >= 0.1.0
[torchaudio](https://pytorch.org/audio/) >= 0.6.0
[sklearn](https://scikit-learn.org/stable/) >= 0.23.1
[gym](https://github.com/openai/gym) >= 0.17.2
[tqdm](https://github.com/tqdm/tqdm) >= 4.47.0

In summary:
```bash
conda install pytorch torchaudio -c pytorch
pip install torchdiffeq scikit-learn gym tqdm
pip install git+https://github.com/patrick-kidger/torchcde.git
```

### Neural CDEs
```bash
python
>>> import speech_commands
>>> device = 'cuda'
>>> norm = False # don't use our trick
>>> norm = True # use our trick
>>> rtol = 1e-3
>>> atol = 1e-5
>>> results = speech_commands.main(device, norm, rtol, atol)
>>> print(results.keys()) # inspect results object
>>> print(results.test_metrics.accuracy) # query results object
```

### Symplectic ODE-Net
```bash
python
>>> import acrobot
>>> device = 'cuda'
>>> norm = False # don't use our trick
>>> norm = True # use our trick
>>> results = acrobot.main(device, norm)
>>> print(results.keys()) # inspect results object
>>> print(results.test_metrics.loss) # query results object
```
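The same calls can also be run non-interactively; a sketch, assuming it is executed from the repository root so that `speech_commands` and `acrobot` are importable, and that a CUDA device is available:
```python
# Script-style variant of the REPL sessions above (same calls, same arguments).
import speech_commands
import acrobot

device = 'cuda'
norm = True  # use the seminorm trick

cde_results = speech_commands.main(device, norm, 1e-3, 1e-5)  # rtol=1e-3, atol=1e-5
print(cde_results.test_metrics.accuracy)

ode_results = acrobot.main(device, norm)
print(ode_results.test_metrics.loss)
```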
---

## Citation
```bibtex
@article{kidger2021hey,
    author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry},
    title={{``Hey, that's not an ODE'': Faster ODE Adjoints via Seminorms}},
    year={2021},
    journal={International Conference on Machine Learning}
}
```