https://github.com/patrick-kidger/esm2quinox
An implementation of ESM2 in Equinox+JAX
- Host: GitHub
- URL: https://github.com/patrick-kidger/esm2quinox
- Owner: patrick-kidger
- License: apache-2.0
- Created: 2024-06-03T08:28:11.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2025-02-11T14:36:16.000Z (8 months ago)
- Last Synced: 2025-04-20T04:15:33.998Z (6 months ago)
- Topics: equinox, esm2, jax, neural-networks, protein
- Language: Python
- Homepage:
- Size: 18.6 KB
- Stars: 25
- Watchers: 2
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- awesome-jax - esm2quinox - An implementation of ESM2 in Equinox+JAX. (Models and Projects / Inactive Libraries)
README
ESM2quinox
An implementation of ESM2 in Equinox+JAX
## Installation
```
pip install esm2quinox
```

## Public API
See their docstrings for details:
```
esm2quinox
    .ESM2
        .__init__(self, num_layers: int, embed_size: int, num_heads: int, token_dropout: bool, key: PRNGKeyArray)
        .__call__(self, tokens: Int[np.ndarray | jax.Array, " length"]) -> esm2quinox.ESM2Result
    .ESM2Result
        .hidden: Float[Array, "length embed_size"]
        .logits: Float[Array, "length alphabet_size"]
    .tokenise(proteins: list[str], length: None | int = None, key: None | PRNGKeyArray = None)
    .from_torch(torch_esm2: esm.ESM2) -> esm2quinox.ESM2
```

## Quick examples
Load an equivalent pretrained model from PyTorch:
```python
import esm # pip install fair-esm==2.0.0
import esm2quinox

torch_model, _ = esm.pretrained.esm2_t6_8M_UR50D()
model = esm2quinox.from_torch(torch_model)
```
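An aside not taken from the upstream README: once converted, the model can be saved and reloaded with Equinox's serialisation helpers, so the PyTorch conversion only needs to run once. A minimal sketch, assuming the 8M checkpoint's hyperparameters (6 layers, embed size 320, 20 heads) and an arbitrary filename:

```python
import equinox as eqx
import jax.random as jr

# Save the converted weights to disk (filename is arbitrary).
eqx.tree_serialise_leaves("esm2_t6_8M.eqx", model)

# Later: build a skeleton with hyperparameters assumed to match the 8M
# checkpoint, then load the saved weights into it.
skeleton = esm2quinox.ESM2(
    num_layers=6, embed_size=320, num_heads=20, token_dropout=True, key=jr.key(0)
)
model = eqx.tree_deserialise_leaves("esm2_t6_8M.eqx", skeleton)
```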
Create a randomly-initialised model:

```python
import esm2quinox
import jax.random as jr

key = jr.key(1337)
model = esm2quinox.ESM2(num_layers=3, embed_size=32, num_heads=2, token_dropout=False, key=key)
```
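Another aside not from the README: an `ESM2` instance is an ordinary Equinox module, i.e. a JAX PyTree, so generic JAX/Equinox tooling applies directly. For example, counting its array parameters:

```python
import equinox as eqx
import jax

# Keep only the array leaves of the model, then sum their sizes.
params = eqx.filter(model, eqx.is_array)
n_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(f"{n_params:,} parameters")
```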
Forward pass (note the model operates on unbatched data):

```python
import jax

proteins = esm2quinox.tokenise(["SPIDERMAN", "FOO"])
out = jax.vmap(model)(proteins)
out.hidden # hidden representation from last layer
out.logits # logits for masked positions
```
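A follow-on sketch (not from the upstream README): the logits are per-position scores over the token alphabet, so greedy per-position predictions are just an argmax over the last axis. Mapping token IDs back to amino-acid characters is not shown here.

```python
import jax.numpy as jnp

# Greedy prediction: most likely token ID at every position of every sequence.
predicted_ids = jnp.argmax(out.logits, axis=-1)  # shape: (batch, length)
```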