Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/jinxu06/gsubsampling
Reference implementation for "Group Equivariant Subsampling"
https://github.com/jinxu06/gsubsampling
Last synced: 3 months ago
JSON representation
Reference implementation for "Group Equivariant Subsampling"
- Host: GitHub
- URL: https://github.com/jinxu06/gsubsampling
- Owner: jinxu06
- License: apache-2.0
- Created: 2021-06-08T15:05:37.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2021-12-13T22:11:06.000Z (almost 3 years ago)
- Last Synced: 2024-07-04T00:58:34.042Z (4 months ago)
- Language: Python
- Homepage:
- Size: 5.39 MB
- Stars: 16
- Watchers: 2
- Forks: 1
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# Group Equivariant Subsampling
This is a reference implementation for [Group Equivariant Subsampling](https://arxiv.org/abs/2106.05886) by Jin Xu, Hyunjik Kim, Tom Rainforth and Yee Whye Teh.
## Dependencies
See `environment.yml`. For [Anaconda](https://docs.anaconda.com/anaconda/install/) users, please create conda environment with `conda env create -f environment.yml`
## Data
For [dSprites](https://github.com/deepmind/dsprites-dataset) and [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist), data will be automatically downloaded and preprocessed before your first run.
For [multi-object datatsets](https://github.com/deepmind/multi_object_datasets) such as Multi-dSprites, please first run
```
python multi_object_datasets/load.py --dataset multi_dsprites --datadir "/tmp"
```## Running the code
### Training autoencoders
Train `ConvAE`, `GConvAE-p4`, `GConvAE-p4m`, `GAE-p1`, `GAE-p4`, `GAE-p4m` on `dSprites`:
```
python main.py hydra.job.name=sample_complexity model=conv_ae run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
```
```
python main.py hydra.job.name=sample_complexity model=gconv_ae model.n_channels=21 model.fiber_group='rot_2d' model.n_rot=4 run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
```
```
python main.py hydra.job.name=sample_complexity model=gconv_ae model.n_channels=15 model.fiber_group='flip_rot_2d' model.n_rot=4 run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
```
```
python main.py hydra.job.name=sample_complexity model=eqv_ae run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
```
```
python main.py hydra.job.name=sample_complexity model=eqv_ae model.n_channels=26 model.fiber_group='rot_2d' model.n_rot=4 run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
```
```
python main.py hydra.job.name=sample_complexity model=eqv_ae model.n_channels=18 model.fiber_group='flip_rot_2d' model.n_rot=4 model.n_rot=4 run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
```
The numbers of channels are rescaled so that the above models have similar number of parameters. To train on `FashionMNIST`, one can simply set `data=fashion_mnist`. To show the progress bar during training, set `run.use_prog_bar=True`.### Evaluate autoencoders
To visualise image reconstructions, set `run.mode=reconstruct`. For example, for `GAE-p1` on `dSprites`,
```
python main.py hydra.job.name=sample_complexity model=eqv_ae run.mode=reconstruct data=dsprites data.train_set_size=1600 run.random_seed=1
```
To evaluate the trained model, set set `run.mode=eval` and run:
```
python main.py hydra.job.name=sample_complexity model=eqv_ae run.mode=eval eval.which_set test data=dsprites data.train_set_size=1600 run.random_seed=1
```### Out-of-distribution experiments
To train autoencoders on constrained data for out-of-distribution experiments, one can run (using `ConvAE` as an example):
```
python main.py hydra.job.name=ood model=conv_ae run.mode=train data=dsprites data.train_set_size=6400 data.constrained_transform="translation_rotation"
```To regenerate the visualisation in the paper, use our python script at `py_scripts/out_of_distribution.py` (coming soon).
### Multi-object experiments
To train MONet baseline, run:
```
python main.py hydra.job.name=compare_to_monet model=monet run.mode=train data=multi_dsprites data.train_set_size=6400 data.batch_size=16 run.max_epochs=1000 run.random_seed=1
```
To train MONet-GAE-p1, run:
```
python main.py hydra.job.name=compare_to_monet model=eqv_monet run.mode=train data=multi_dsprites data.train_set_size=6400 data.batch_size=16 run.max_epochs=1000 run.random_seed=1
```## About this repository
We use [Hydra](https://hydra.cc/) to specify configurations for experiments. " The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line." Our default hydra configurations can be found at `conf/`.
The directory `elm/` contains most of our research code. All the data loaders can be found at `elm/data_loader/`, and all the models can be found at `elm/model/`. Experimental results will be generated at `outputs/`, organised by dates and job names. By default, Logs and checkpoints are directed to `/tmp/log/`, but this can be reconfigured in `conf/config.yaml`.
## Contact
To ask questions about code or report issues, please directly open an issue on github.
To discuss research, please email [email protected]## Acknowledgements
This repository includes code from two previous projects: [GENESIS](https://github.com/applied-ai-lab/genesis) and [Multi_Object_datasets](https://github.com/deepmind/multi_object_datasets). Their original licenses have been included.