Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/0x7o/retro-transformer
Easy-to-use Retrieval-Enhanced Transformer implementation
- Host: GitHub
- URL: https://github.com/0x7o/retro-transformer
- Owner: 0x7o
- License: apache-2.0
- Created: 2022-07-30T05:21:56.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2022-09-30T03:34:34.000Z (about 2 years ago)
- Last Synced: 2024-11-15T08:05:24.466Z (5 days ago)
- Topics: attention-mechanism, deep-learning, language-model, pytorch, retrieval, transformers
- Language: Python
- Homepage:
- Size: 586 KB
- Stars: 9
- Watchers: 2
- Forks: 4
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
![RETRO](data/RETRO.png)
# Retrieval-Enhanced Transformer

Easy-to-use [Retro](https://arxiv.org/abs/2112.04426) implementation in PyTorch.
This code is based on [labml.ai](https://nn.labml.ai/transformers/retro/index.html) and uses [accelerate](https://github.com/huggingface/accelerate) for lightweight inference and training on CPUs, GPUs, and TPUs.
```python
from retro_transformer.bert import BERTForChunkEmbeddings
from retro_transformer.tools.database import build_database, RetroIndex
from retro_transformer.tools.dataset import build_dataset
from retro_transformer.model import RetroModel, NearestNeighborEncoder
from retro_transformer.tools.train import train

# Model and data hyperparameters
chunk_len = 16
d_model = 128
d_ff = 512
n_heads = 16
d_k = 16
n_layers = 16
workspace = './workspace'
text_file = 'text.txt'

# Embed text chunks with BERT and build the retrieval index
bert = BERTForChunkEmbeddings('bert-base-uncased', 'cuda')
index = RetroIndex(workspace, chunk_len, bert=bert)

# Build the chunk database and the training dataset from the text file
build_database(workspace, text_file, bert=bert, chunk_len=chunk_len)
num_tokens = build_dataset(workspace, text_file, chunk_len=chunk_len, index=index)

# Encoder for retrieved neighbour chunks, with cross-attention at layer 3
nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len=chunk_len, n_layers=n_layers,
                                                  d_model=d_model, d_ff=d_ff, n_heads=n_heads,
                                                  d_k=d_k, ca_layers={3})

# RETRO decoder with chunked cross-attention at layers 3 and 5
model = RetroModel(n_vocab=num_tokens, d_model=d_model, n_layers=n_layers, chunk_len=chunk_len,
                   n_heads=n_heads, d_k=d_k, d_ff=d_ff,
                   encoder=nearest_neighbor_encoder, ca_layers={3, 5})

train(model, workspace, text_file, chunk_len=chunk_len, d_model=d_model)
```
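The core idea behind RETRO is simple: split text into fixed-length chunks, embed each chunk, and retrieve the nearest neighbours of every input chunk from a precomputed database so the model can attend to them. The sketch below illustrates that retrieval step in plain Python; it is a conceptual toy, not part of the `retro_transformer` API — the `embed`, `cosine`, `chunkify`, and `retrieve` helpers are made up here, and a character-frequency vector stands in for the real BERT chunk embeddings.

```python
import math

def embed(chunk):
    # Toy "embedding": a character-frequency vector over a-z
    # (a stand-in for BERTForChunkEmbeddings).
    vec = [0.0] * 26
    for ch in chunk.lower():
        if ch.isalpha():
            vec[ord(ch) - ord('a')] += 1.0
    return vec

def cosine(a, b):
    # Cosine similarity between two embedding vectors.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

def chunkify(text, chunk_len=16):
    # Split text into fixed-length chunks, as RETRO does with tokens.
    return [text[i:i + chunk_len] for i in range(0, len(text), chunk_len)]

def retrieve(query_chunk, database, k=2):
    # Return the k database chunks most similar to the query chunk.
    q = embed(query_chunk)
    return sorted(database, key=lambda c: cosine(q, embed(c)), reverse=True)[:k]

database = chunkify("retrieval enhanced transformers condition on neighbour chunks")
neighbours = retrieve("retrieval chunks", database)
print(neighbours)
```

In the real implementation this lookup runs over BERT embeddings in the `RetroIndex`, and the retrieved neighbours are fed to the `NearestNeighborEncoder` through chunked cross-attention.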