Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/google-research/perceiver-ar
- Host: GitHub
- URL: https://github.com/google-research/perceiver-ar
- Owner: google-research
- License: apache-2.0
- Created: 2022-05-26T17:28:26.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-05-06T18:36:00.000Z (6 months ago)
- Last Synced: 2024-05-09T17:12:11.777Z (6 months ago)
- Language: Python
- Size: 73.2 KB
- Stars: 227
- Watchers: 12
- Forks: 21
- Open Issues: 24
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
# Perceiver AR
Perceiver AR is an autoregressive, modality-agnostic architecture which uses
cross-attention to map long-range inputs to a small number of latents while also
maintaining end-to-end causal masking. Perceiver AR can directly attend to over
a hundred thousand tokens, enabling practical long-context density estimation
without the need for hand-crafted sparsity patterns or memory mechanisms.

For more details, see our ICML 2022 paper: https://arxiv.org/abs/2202.07765
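The causal cross-attention described above can be illustrated with a minimal, single-head NumPy sketch. It is not the repository's implementation. It assumes that the N latents are aligned with the last N input positions, so latent i may attend to every input position up to and including its own, and it leaves out learned projections. The function names `perceiver_ar_causal_mask` and `causal_cross_attend` are hypothetical.

```python
import numpy as np


def perceiver_ar_causal_mask(input_len: int, num_latents: int) -> np.ndarray:
    """Boolean [num_latents, input_len] mask; True means "may attend".

    Assumption: latent i sits at input position input_len - num_latents + i
    and may see that position and every earlier one.
    """
    if not 0 < num_latents <= input_len:
        raise ValueError("need 0 < num_latents <= input_len")
    offset = input_len - num_latents
    latent_pos = offset + np.arange(num_latents)[:, None]  # [N, 1]
    input_pos = np.arange(input_len)[None, :]  # [1, L]
    return input_pos <= latent_pos


def causal_cross_attend(latents, inputs, mask):
    """Single-head cross-attention from N latents (queries) to L inputs.

    Keys and values are the raw inputs; no learned projections (sketch only).
    """
    d = latents.shape[-1]
    logits = latents @ inputs.T / np.sqrt(d)  # [N, L]: cost is O(N * L)
    logits = np.where(mask, logits, -np.inf)  # hide future positions
    logits -= logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)  # masked entries become exactly 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ inputs  # [N, D]


rng = np.random.default_rng(0)
inputs = rng.normal(size=(6, 4))  # L = 6 input tokens of width D = 4
latents = inputs[-2:]  # N = 2 latents taken from the last 2 tokens
mask = perceiver_ar_causal_mask(6, 2)
print(mask.astype(int))
# [[1 1 1 1 1 0]
#  [1 1 1 1 1 1]]
out = causal_cross_attend(latents, inputs, mask)  # shape (2, 4)
```

Because the attention logits are N × L rather than L × L, the cost of this step grows linearly with the input length for a fixed number of latents. In the full model, causally masked self-attention layers then run over the latents only, which this sketch omits.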
An implementation is also available for [T5X/Flaxformer](https://github.com/google/flaxformer/tree/main/flaxformer/architectures/perceiver_ar).

## Setup
First, install dependencies following these instructions:
1. Create a virtual env: `python3 -m venv ~/.venv/perceiver-ar`
2. Switch to the virtual env: `source ~/.venv/perceiver-ar/bin/activate`
3. Follow instructions for installing JAX on your platform:
https://github.com/google/jax#installation
4. Install other dependencies: `pip install -r requirements.txt`

## Training
As an example, a 32-position version of the Copy Task from our paper can be
trained using only a local CPU:

```
PYTHONPATH=.:$PYTHONPATH python perceiver_ar/experiment.py \
--config=perceiver_ar/experiment.py:random_mirrored_32
```

By default, checkpoints and events are saved to `/tmp/perceiver_ar`.
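To build intuition for what the 32-position Copy Task asks of the model, the sketch below constructs one example under an assumption suggested by the config name `random_mirrored_32`: 16 random tokens followed by the same tokens in reverse order. This is not the repository's data pipeline; `make_mirrored_copy_example`, `copy_source`, and `vocab_size` are hypothetical.

```python
import numpy as np


def make_mirrored_copy_example(seq_len: int, vocab_size: int,
                               rng: np.random.Generator) -> np.ndarray:
    """Hypothetical format: a random first half, then that half reversed."""
    if seq_len <= 0 or seq_len % 2:
        raise ValueError("seq_len must be a positive even number")
    first_half = rng.integers(0, vocab_size, size=seq_len // 2)
    return np.concatenate([first_half, first_half[::-1]])


def copy_source(t: int, seq_len: int) -> int:
    """Position that second-half position t copies from, under mirroring."""
    return seq_len - 1 - t


rng = np.random.default_rng(0)
example = make_mirrored_copy_example(32, vocab_size=256, rng=rng)
assert (example[16:] == example[:16][::-1]).all()
# How far back a second-half token must look: 1 step at t=16, 31 at t=31.
print([t - copy_source(t, 32) for t in (16, 31)])  # [1, 31]
```

Under this assumption, every second-half token is fully determined by earlier tokens, and the final token must reach back to position 0, so the task can only be solved by a model that uses its whole context.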
Training metrics are periodically written to TensorBoard event files, which
can be viewed using:

```
tensorboard --logdir /tmp/perceiver_ar/
```

During training, press Ctrl+C to save a checkpoint, or Ctrl+\ to save a
checkpoint and exit.

## Evaluation
To evaluate the latest saved checkpoint:
```
CHECKPOINTS="/tmp/perceiver_ar"
LATEST_CHECKPOINT="${CHECKPOINTS}/models/latest/$(ls -tr ${CHECKPOINTS}/models/latest/ | tail -n 1)"
echo "Evaluating ${LATEST_CHECKPOINT}"
PYTHONPATH=.:$PYTHONPATH python perceiver_ar/experiment.py \
--config=perceiver_ar/experiment.py:random_mirrored_32 \
--jaxline_mode=eval \
--config.one_off_evaluate=True \
--config.restore_path="${LATEST_CHECKPOINT}"
```

Results are written to the console and can also be viewed in TensorBoard.
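If you prefer to select the checkpoint from Python (for example, from a notebook), the `ls -tr ... | tail -n 1` idiom above can be approximated with `pathlib`. `latest_checkpoint` is a hypothetical helper, not part of the repository, and it assumes the same `models/latest/` layout that the shell snippet relies on.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoints_dir: str) -> Optional[Path]:
    """Most recently modified entry in <checkpoints_dir>/models/latest.

    Roughly mirrors `ls -tr .../models/latest/ | tail -n 1`: hidden entries
    are skipped, as plain `ls` does. Returns None if the directory is empty.
    """
    latest_dir = Path(checkpoints_dir) / "models" / "latest"
    entries = [p for p in latest_dir.iterdir() if not p.name.startswith(".")]
    if not entries:
        return None
    return max(entries, key=lambda p: p.stat().st_mtime)
```

For example, `latest_checkpoint("/tmp/perceiver_ar")` returns a path that could be passed as `--config.restore_path`. Tie-breaking between entries with identical modification times is not guaranteed to match `ls`.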
## Inference
To run inference in a local Jupyter notebook:
```
jupyter notebook
```

Load `inference.ipynb` and follow the instructions in the notebook.
### Pretrained Copy Task
The notebook also supports loading a pretrained checkpoint for the
131k-position copy task used in our paper. This model is fairly large,
so generating more than a few positions will likely require a large
accelerator. The notebook has been tested on a GCP
[TPU VM](https://cloud.google.com/tpu/docs/users-guide-tpu-vm) with a
TPU v3-8.

## Unit Tests
To run all unit tests:
```
pytest
```

## Disclaimer
This is not an officially supported Google product.