Implementation and replication of ProGen, Language Modeling for Protein Generation, in Jax

- Host: GitHub
- URL: https://github.com/lucidrains/progen
- Owner: lucidrains
- License: mit
- Created: 2021-06-09T14:44:17.000Z (over 4 years ago)
- Default Branch: main
- Last Pushed: 2021-09-08T20:28:10.000Z (about 4 years ago)
- Last Synced: 2024-12-09T18:11:46.041Z (10 months ago)
- Topics: artificial-intelligence, deep-learning, proteins
- Language: Python
- Homepage:
- Size: 204 KB
- Stars: 111
- Watchers: 8
- Forks: 17
- Open Issues: 3
Metadata Files:
- Readme: README.md
- License: LICENSE
README
## ProGen - (wip)
Implementation and replication of ProGen, Language Modeling for Protein Generation, in Pytorch and Jax (the weights will be made easily transferable between the two). You can think of this as GPT for protein sequences.
## Requirements
We are going to use Poetry to manage the dependencies for this project, so first install it using its one-line bash installer.
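The original README links out to Poetry's installer; as of this writing, the official one-liner is the following (assuming `python3` is available on your path):
```bash
$ curl -sSL https://install.python-poetry.org | python3 -
```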
Next, git clone the project and install the dependencies
```bash
$ git clone git@github.com:lucidrains/progen
$ cd progen
$ poetry install
```

For training on GPUs, you may need to rerun `pip install` with the correct CUDA version. You can follow the instructions in the Jax installation guide.
```bash
# ex. CUDA 11.1
$ pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

For running any scripts, you'll notice that they are always prefixed with `poetry run`
## Usage
```python
from jax import random
from haiku import PRNGSequence
from progen_transformer import ProGen

model = ProGen(
    num_tokens = 256,
    dim = 512,
    seq_len = 1024,
    window_size = 256,       # local attention window size
    depth = 12,              # number of layers
    heads = 8,               # attention heads
    dim_head = 64,           # dimension per head
    ff_glu = True,           # use GLU in feedforward, from Noam Shazeer's paper
    global_mlp_depth = 2     # last N global gmlp layers
)

rng = PRNGSequence(42)
seq = random.randint(next(rng), (1024,), 0, 256)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 256)
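
# hypothetical quick check (not from the original README): greedily pick the
# predicted next token at the final position from the per-position logits
next_token = logits[-1].argmax(-1)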
```

## Training
Download Uniref50 from UniProt and place `uniref50.fasta` in the root directory
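As an optional sanity check before preprocessing (this snippet is not part of the repository), you can peek at the first couple of record headers in the downloaded FASTA file:
```python
# print the first two FASTA record headers from uniref50.fasta
from itertools import islice

with open('uniref50.fasta') as f:
    headers = (line.rstrip() for line in f if line.startswith('>'))
    for header in islice(headers, 2):
        print(header)
```
Then kick off the data preprocessing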
```bash
$ poetry run python generate_data.py
```

You should see a lot of green if everything succeeds. Then
```bash
$ poetry run python train.py
```

By default, the script will checkpoint and resume automatically, but if you wish to clear your progress and restart, just add a `--new` flag
```bash
$ poetry run python train.py --new
```

Model checkpoints will be saved periodically to `./ckpts`
Finally, to sample from your checkpoint, just do
```bash
$ poetry run python sample.py
```

You can pass a prompt with the `--prime` flag. You can either pass the annotations, followed by `#`, to get the generated sequence, or pass the sequence (also followed by `#`) to get the generated annotations
```bash
$ poetry run python sample.py --prime "[Tax=Mammalia] #"
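# hypothetical example of the reverse direction (arbitrary sequence, not from the repository):
# prime with a sequence followed by '#' to generate annotations
$ poetry run python sample.py --prime "MTYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE #"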
```

## Mixed Precision
To use mixed precision training, you'll need to install the latest Haiku with the following command
```bash
$ pip install git+https://github.com/deepmind/dm-haiku
```

Then make sure to set the `--mixed_precision` flag when invoking the training script
```bash
$ poetry run python train.py --mixed_precision
```

## Todo
- [ ] model parallelism with pjit
- [ ] join in GO annotations with pandas dataframe
- [ ] setup annotation -> template string system, all configuration driven, find easy way to test. offer two types of annotations, one parsed from uniref descriptions, the other from GO annotation presence
- [ ] add multiple data sources (check out trembl)
- [ ] when sampling, prime with entire sequence prior to the pound sign (intersection of sequence and annotation)
- [ ] utilize all cores when processing data
- [ ] save all training settings in the checkpoints too
- [x] bfloat16 on xla
- [x] resume from correct place in tfrecord even if batch size is changed between runs, display number of sequences processed
- [x] train compressed gzip tfrecords from google cloud storage path
- [x] remove tfrecord package and just use tfrecordwriter with gzip
- [x] generate validation tfrecords
- [x] checkpoint and resume from a google cloud storage path
- [x] use jinja2 for wandb html sample logging
- [x] manage experimental tracker state, and also allow ability to turn it off by piping to noop
- [x] add a confirmation before clearing a folder for --new run
- [x] engineer mask in cross entropy loss so that padding can be reused as end-of-string token
- [x] flip seq # annotation order with prob set in config
- [x] keep N last checkpoints

## Acknowledgements
Many thanks go out to Ben Wang, who showed that this type of large-scale training can be achieved with GPT-J
## Citations
```bibtex
@misc{madani2020progen,
title = {ProGen: Language Modeling for Protein Generation},
author = {Ali Madani and Bryan McCann and Nikhil Naik and Nitish Shirish Keskar and Namrata Anand and Raphael R. Eguchi and Po-Ssu Huang and Richard Socher},
year = {2020},
eprint = {2004.03497},
archivePrefix = {arXiv},
primaryClass = {q-bio.BM}
}
```

```bibtex
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```

```bibtex
@misc{shazeer2020glu,
title = {GLU Variants Improve Transformer},
author = {Noam Shazeer},
year = {2020},
eprint = {2002.05202},
archivePrefix = {arXiv},
url = {https://arxiv.org/abs/2002.05202}
}
```