https://github.com/lucidrains/marge-pytorch

## Marge - Pre-training via Paraphrasing

Implementation of Marge, Pre-training via Paraphrasing, in Pytorch. It is an alternative to masked language modeling pretraining, where an encoder / decoder attention network learns to reconstruct a target document from a collection of evidence documents.

Update: Three researchers have independently reported that the repository works for them.

## Install

```bash
$ pip install marge-pytorch
```

## Usage

```python
import torch
import numpy as np
from torch.utils.data import DataLoader

from marge_pytorch import Marge, TrainingWrapper

# your documents must be tokenized and stored as memmap in the shape (num documents, seq length)

# constants
NUM_DOCS = 10000
SEQ_LEN = 1024
SHAPE = (NUM_DOCS, SEQ_LEN)

# generate mock training data
f = np.memmap('./train.dat', dtype=np.int32, mode='w+', shape=SHAPE)
f[:] = np.random.randint(0, 20000, size=SHAPE)
del f

# generate mock masking data
f = np.memmap('./train.mask.dat', dtype=np.bool_, mode='w+', shape=SHAPE)  # np.bool was removed in recent NumPy releases
f[:] = np.full(SHAPE, True)
del f

# instantiate model

model = Marge(
    dim = 512,
    num_tokens = 20000,
    max_seq_len = SEQ_LEN,
    enc_depth = 12,
    enc_retrieval_depth = 4,   # defaults to 4 as in paper (take the CLS token after the 4th layer of the encoder)
    enc_heads = 8,
    enc_ff_mult = 4,
    dec_depth = 12,
    dec_heads = 8,
    dec_ff_mult = 16,          # paper noted that decoder needs to have much bigger feed forward sizes
    distill_attn = False,      # (experimental) will add, on top of the decoder loss, an auxiliary distillation loss as defined in https://arxiv.org/abs/2012.04584
    distill_loss_coef = 1.     # weight of the auxiliary distillation loss
)

# wrap your model and your documents

trainer = TrainingWrapper(
    model,
    num_documents = NUM_DOCS,
    doc_seq_len = SEQ_LEN,
    num_evidence = 4,                         # number of evidence documents to fetch per target document to construct
    reindex_batch_size = 32,                  # batch size to use when reindexing
    documents_memmap_path = './train.dat',    # path to the mem-mapped documents
    masks_memmap_path = './train.mask.dat',   # if None is supplied, will assume all tokens are visible
    use_faiss_ann = True                      # set this to False if you have a low number of documents, and approximate nearest neighbor is not needed
)

# instantiate dataloader

dl = DataLoader(trainer.dataset, batch_size=16)

# now you can train, and use the reindex method on the training wrapper at appropriate intervals

for ind, data in enumerate(dl):
    loss = trainer(data)
    loss.backward()
    # optimizer step and all that

    # reindex and precompute knn every 10000 steps, as in paper
    if ind > 0 and ind % 10000 == 0:
        trainer.reindex()
```
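
The mock data above stands in for a real corpus. As a minimal sketch of preparing real inputs, assuming you already have each document as a list of token ids, you could pad or truncate every document to a fixed length and write the ids and a boolean mask in the same layout. The `docs` list, the `PAD_ID` of 0, the output paths, and the `write_memmaps` helper are all hypothetical and not part of `marge-pytorch`.

```python
import os
import tempfile
import numpy as np

PAD_ID = 0  # assumption: id used for padding, use whatever your tokenizer defines

def write_memmaps(docs, seq_len, ids_path, mask_path):
    """Pad / truncate token id lists to seq_len and store ids and masks as memmaps."""
    shape = (len(docs), seq_len)
    ids = np.memmap(ids_path, dtype=np.int32, mode='w+', shape=shape)
    mask = np.memmap(mask_path, dtype=np.bool_, mode='w+', shape=shape)

    ids[:] = PAD_ID
    mask[:] = False

    for row, tokens in enumerate(docs):
        tokens = tokens[:seq_len]           # truncate long documents
        ids[row, :len(tokens)] = tokens
        mask[row, :len(tokens)] = True      # True marks real (visible) tokens

    ids.flush()
    mask.flush()
    del ids, mask
    return shape

# hypothetical tokenized documents of uneven length
docs = [[5, 8, 13], [21, 34, 55, 89, 144, 233], [2]]

out_dir = tempfile.mkdtemp()
ids_path = os.path.join(out_dir, 'real.dat')
mask_path = os.path.join(out_dir, 'real.mask.dat')

shape = write_memmaps(docs, seq_len=4, ids_path=ids_path, mask_path=mask_path)
```

The two paths can then be passed as `documents_memmap_path` and `masks_memmap_path`, with `num_documents` and `doc_seq_len` taken from `shape`.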

Save your model after training:

```python
torch.save(model, './trained-model.pt')
```
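
Saving the whole module pickles the object itself, so loading it later needs the same class definitions to be importable, and recent PyTorch versions also require `torch.load(..., weights_only=False)` for such files. A common alternative is to save only the `state_dict`. The sketch below uses a small `nn.Linear` as a stand-in for the `Marge` instance purely so that it runs on its own; the file name is hypothetical.

```python
import os
import tempfile
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the Marge instance, for illustration only

path = os.path.join(tempfile.mkdtemp(), 'trained-model.state.pt')
torch.save(model.state_dict(), path)

# later: rebuild the model with the same hyperparameters, then load the weights
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path, map_location='cpu'))
restored.eval()
```

With the real model, `restored` would be a fresh `Marge(...)` constructed with the same arguments used for training.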

## Advanced

If you would like the target and evidence documents to be from different sets, you just have to pass in up to four additional keyword arguments, as shown below.

```python
trainer = TrainingWrapper(
    model,
    num_documents = NUM_DOCS,
    doc_seq_len = SEQ_LEN,
    num_evidence = 4,
    reindex_batch_size = 32,
    documents_memmap_path = './evidence.dat',
    masks_memmap_path = './evidence.mask.dat',
    num_targets = NUM_TARGETS,                        # 1. number of target documents, with sequence length the same as the document (evidence)
    target_seq_len = SEQ_LEN,                         # 2. sequence length of target documents
    target_memmap_path = './target.dat',              # 3. path to target memmap, same as documents (evidence)
    target_masks_memmap_path = './target.mask.dat',   # 4. path to target mask memmap, same as document masks (evidence)
    use_faiss_ann = True
)
```
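
With separate target and evidence sets there are two shapes to keep straight, and a memmap opened with the wrong shape does not necessarily fail loudly. A small sanity check, sketched here under the assumption that tokens are stored as `int32` as in the usage example, derives the document count from the file size. The `expected_docs` helper and the mock file are hypothetical and not part of the library.

```python
import os
import tempfile
import numpy as np

def expected_docs(path, seq_len, dtype=np.int32):
    """Return the number of documents implied by the file size, or raise if it does not divide evenly."""
    row_bytes = seq_len * np.dtype(dtype).itemsize
    size = os.path.getsize(path)
    if size % row_bytes != 0:
        raise ValueError(f'{path}: {size} bytes is not a multiple of {row_bytes} bytes per document')
    return size // row_bytes

# mock target file with 7 documents of length 16
tmp = tempfile.mkdtemp()
target_path = os.path.join(tmp, 'target.dat')
f = np.memmap(target_path, dtype=np.int32, mode='w+', shape=(7, 16))
f[:] = 1
f.flush()
del f

# value to pass as num_targets
NUM_TARGETS = expected_docs(target_path, seq_len=16)
```

The same check can be run on the evidence file to confirm `num_documents` before constructing the wrapper.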

## Sampling

You can sample from the decoder as follows:

```python
# some random evidence from the dataset
# or provide your own in the dimensions (b x num_evidences x seq_len)
*_, evidence, mask = trainer.dataset[0:1]

# assume 1 is start token
prime = torch.tensor([[1.]]).long().cuda()

# supply your own document similarities array (b x num_evidences)
# if not supplied, will default to 1. for all evidence
doc_similarities = torch.ones(evidence.shape[:2]).float().cuda()

# generate sample of length 1024
samples = model.generate(prime, 1024, evidence, mask = mask, similarities = doc_similarities)
```
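
The `similarities` argument takes one score per evidence document. As an illustration only, here is one way such scores could be computed, assuming you have an embedding for the prompt and for each evidence document. The paper scores relevance by cosine similarity between document embeddings, which is what this sketch does. The random `query_emb` and `evidence_emb` arrays are placeholders and the `cosine_similarities` helper is hypothetical; the library's own embedding and retrieval code is not shown here.

```python
import numpy as np

def cosine_similarities(query_emb, evidence_emb, eps=1e-8):
    """query_emb: (b, d), evidence_emb: (b, num_evidences, d) -> (b, num_evidences)"""
    q = query_emb / (np.linalg.norm(query_emb, axis=-1, keepdims=True) + eps)
    e = evidence_emb / (np.linalg.norm(evidence_emb, axis=-1, keepdims=True) + eps)
    # dot product of unit vectors, one score per (batch, evidence) pair
    return np.einsum('bd,bnd->bn', q, e)

rng = np.random.default_rng(0)
query_emb = rng.normal(size=(1, 512))         # placeholder embedding for the prompt / target
evidence_emb = rng.normal(size=(1, 4, 512))   # placeholder embeddings for 4 evidence documents

sims = cosine_similarities(query_emb, evidence_emb)
```

Convert the result with `torch.from_numpy(sims).float()` and move it to the model's device before passing it as `similarities`.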

## Citations

```bibtex
@misc{lewis2020pretraining,
    title={Pre-training via Paraphrasing},
    author={Mike Lewis and Marjan Ghazvininejad and Gargi Ghosh and Armen Aghajanyan and Sida Wang and Luke Zettlemoyer},
    year={2020},
    eprint={2006.15020},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

```bibtex
@misc{komatsuzaki2020current,
    title={Current Limitations of Language Models: What You Need is Retrieval},
    author={Aran Komatsuzaki},
    year={2020},
    eprint={2009.06857},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

```bibtex
@misc{izacard2020distilling,
    title={Distilling Knowledge from Reader to Retriever for Question Answering},
    author={Gautier Izacard and Edouard Grave},
    year={2020},
    eprint={2012.04584},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```