https://github.com/yuneg11/neural-process-family
Neural Process implementations in JAX and PyTorch
https://github.com/yuneg11/neural-process-family
jax neural-processes pytorch
Last synced: 2 months ago
JSON representation
Neural Process implementations in JAX and PyTorch
- Host: GitHub
- URL: https://github.com/yuneg11/neural-process-family
- Owner: yuneg11
- License: mit
- Created: 2021-03-18T05:18:32.000Z (about 4 years ago)
- Default Branch: master
- Last Pushed: 2022-11-07T15:46:47.000Z (over 2 years ago)
- Last Synced: 2024-09-19T13:21:27.539Z (9 months ago)
- Topics: jax, neural-processes, pytorch
- Language: Jupyter Notebook
- Homepage:
- Size: 1.51 MB
- Stars: 5
- Watchers: 2
- Forks: 2
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# Neural Process Family
**This library is in the early stages of development.**
**PyTorch versions are not yet supported.**
## Installation
You can choose between the following installation methods:
### NPF as a library
```bash
# From PyPI (recommended)
pip install np-family# Latest version (from current branch; dev-v4)
pip install git+https://github.com/yuneg11/Neural-Process-Family@dev-v4# Specific release (from tag; 0.0.1.dev0)
pip install git+https://github.com/yuneg11/[email protected]
```Then, you can use the library as follows:
```python
from npf.jax.models import CNPcnp = CNP(y_dim=1)
```You should handle other logics (include train, evaluation, etc...)
### NPF as an experiment framework
```bash
# Dependencies
pip install rich nxcl==0.0.3.dev3## And ML frameworks (JAX, PyTorch)
# ex) pip install jax# Clone the repository
git clone https://github.com/yuneg11/Neural-Process-Family npf
cd npf
```Then, you can run the experiment, for example:
```bash
python scripts/jax/train.py -f configs/gp/rbf/inf/anp.yaml -lr 0.0001 --model.train_kwargs.num_latents 30
```The output will be saved under `outs/` directory.
Details will be added in the future.## Download or build datasets
```bash
python -m npf.jax.data.save \
--root \
--dataset
```- ``: The root path to save dataset. Default: `./datasets/`
- ``: The name of the dataset to save. See below sections for available datasets.### Image datasets
You should install `torch` and `torchvision` to download the datastes.
You can find the details in the [download page](https://pytorch.org/get-started/locally/).For example,
```bash
# CUDA 11.3
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
```#### Available datasets
- MNIST: `mnist`
- CIFAR10: `cifar10`
- CIFAR100: `cifar100`
- CelebA: `celeba`
- SVHN: `svhn`### Sim2Real datasets
You should install `numba` and `wget` to simulate or download the datasets.
```bash
pip install numba wget
```#### Available datasets
- Lotka Volterra: `lotka_volterra`
TODO: See `npf.jax.data.save:save_lotka_volterra` for more detailed options.
## Models
- **CNP**: Conditional Neural Process
- **NP**: Neural Process
- **CANP**: Conditional Attentive Neural Process
- **ANP**: Attentive Neural Process
- **BNP**: Bootstrapping Neural Process
- **BANP**: Bootstrapping Attentive Neural Process
- **NeuBNP**: Neural Bootstrapping Neural Process
- **NeuBANP**: Neural Bootstrapping Attentive Neural Process
- **ConvCNP**: Convolutional Conditional Neural Process
- **ConvNP**: Convolutional Neural Process## Scripts
### Train
```bash
python scripts/jax/train.py -f [additional-options]
```You can use your own config file or use the provided config files in the `configs` directory.
For example, the following command will train a CNP model with learning rate of `0.0001` for `100` epochs:```bash
python scripts/jax/train.py -f configs/gp/rbf/inf/anp.yaml \
-lr 0.0001 \
--train.num_epochs 100
```You can see the help of the config file by using the following command:
```bash
python scripts/jax/train.py -f --help
```### Test
```bash
# From a trained model directory
python scripts/jax/test.py -d [additional-options]# From a new config file and a trained model checkpoint
python scripts/jax/test.py -f -c [additional-options]
```You can directly test the trained model by specifying the output directory.
For example:```bash
python scripts/jax/test.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh
```where `outs/CNP/Train/RBF/Inf/220704-181313-vweh` is the output directory of the trained model.
You can also replace or add the test-specific configs from the config file using the `-tf / --test-config-file` option.
For example:```bash
python scripts/jax/test.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh \
-tf configs/gp/robust/matern.yaml
```### Test Bayesian optimization
```bash
# From a trained model directory
python scripts/jax/test_bo.py -d [additional-options]# From a new config file and a trained model checkpoint
python scripts/jax/test_bo.py -f -c [additional-options]
```Similar to above the `test` script, you can directly test the trained model by specifying the output directory.
For example:```bash
python scripts/jax/test_bo.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh
```You can also replace or add the test-specific configs from the config file using the `-bf / --bo-config-file` option.
For example:```bash
python scripts/jax/test.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh \
-bf configs/gp/rbf/bo_config.yaml
```
## Appendix
## Datasets
1. 1D regression (`x`: `[B, P, 1]`, `y`: `[B, P, 1]`, `mask`: `[B, P]`)
- Gaussian processes, etc...2. 2D Image (`x`: `[B, P, P, 2]`, `y`: `[B, P, P, (1 or 3)]`, `mask`: `[B, P, P]`)
- Image completion, super resolution, etc...3. Bayesian optimization (`x`: `[B, P, D]`, `y`: `[B, P, 1]`, `mask`: `[B, P]`)
## Dimension rule
- `x`: `[batch, *data_specific_dims, data_dim]`
- `y`: `[batch, *data_specific_dims, data_dim]`
- `mask`: `[batch, *data_specific_dims]`
- `outs`: `[batch, *model_specific_dims, *data_specific_dims, data_dim]`### Examples
1. At `CNP` 1D regression:
- `x`: `[batch, point, 1]`
- `y`: `[batch, point, 1]`
- `mask`: `[batch, point]`
- `outs`: `[batch, point, 1]`2. At `NP` 1D regression:
- `x`: `[batch, point, 1]`
- `y`: `[batch, point, 1]`
- `mask`: `[batch, point]`
- `outs`: `[batch, latent, point, 1]`3. At `CNP` 2D image regression:
- `x`: `[batch, height, width, 2]`
- `y`: `[batch, height, width, 1 or 3]`
- `mask`: `[batch, height, width]`
- `outs`: `[batch, height, width, 1 or 3]`4. At `NP` 2D image regression:
- `x`: `[batch, height, width, 2]`
- `y`: `[batch, height, width, 1 or 3]`
- `mask`: `[batch, height, width]`
- `outs`: `[batch, latent, height, width, 1 or 3]`5. At `BNP` 1D regression:
- `x`: `[batch, point, 1]`
- `y`: `[batch, point, 1]`
- `mask`: `[batch, point]`
- `outs`: `[batch, sample, point, 1]`5. At `BNP` 2D image regression:
- `x`: `[batch, height, width, 2]`
- `y`: `[batch, height, width, 1 or 3]`
- `mask`: `[batch, height, width]`
- `outs`: `[batch, sample, height, width, 1 or 3]`
## Acknowledgements
- We used Cloud TPUs supported by Google’s [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/).
- Synthetic GP dataset codes based on [juho-lee/bnp](https://github.com/juho-lee/bnp).
- SetConv modules based on [wesselb/gabriel-convcnp](https://github.com/wesselb/gabriel-convcnp).