https://github.com/lucidrains/coco-lm-pytorch
Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch
- Host: GitHub
- URL: https://github.com/lucidrains/coco-lm-pytorch
- Owner: lucidrains
- License: mit
- Created: 2021-03-02T17:57:55.000Z (over 4 years ago)
- Default Branch: main
- Last Pushed: 2021-03-03T20:32:45.000Z (over 4 years ago)
- Last Synced: 2025-04-15T01:05:25.823Z (6 months ago)
- Topics: artificial-intelligence, deep-learning, pre-training, transformers
- Language: Python
- Homepage:
- Size: 120 KB
- Stars: 45
- Watchers: 5
- Forks: 7
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
README
## COCO LM Pretraining (wip)
Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch. The authors were able to make contrastive learning work in a self-supervised manner for language model pretraining. It seems like a solid successor to Electra.
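The trainer optimizes a weighted combination of three objectives: the generator's masked language modeling loss, the discriminator's corrective (replaced-token) loss, and a sequence-level contrastive loss, weighted by the `gen_weight`, `disc_weight`, and `cl_weight` arguments shown in the usage example below. A minimal sketch of how the weighting works, with illustrative names rather than the library's internals:

```python
def combine_losses(mlm_loss, corrective_loss, contrastive_loss,
                   gen_weight = 1., disc_weight = 1., cl_weight = 1.):
    # conceptual sketch only: the COCO trainer returns a single scalar loss
    # that behaves like a weighted sum of the three pretraining objectives
    return gen_weight * mlm_loss + disc_weight * corrective_loss + cl_weight * contrastive_loss
```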
## Install
```bash
$ pip install coco-lm-pytorch
```

## Usage
An example using the `x-transformers` library
```bash
$ pip install x-transformers
```
Then:

```python
import torch
from coco_lm_pytorch import COCO

# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to half the size of the discriminator
from x_transformers import TransformerWrapper, Encoder
generator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 256,          # smaller hidden dimension
        heads = 4,          # fewer heads
        ff_mult = 2,        # smaller feedforward dimension
        depth = 1           # shallower depth
    )
)

discriminator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 1024,
        heads = 16,
        ff_mult = 4,
        depth = 12
    )
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb

# weight tie any other embeddings if available, e.g. token type embeddings
# (3) instantiate COCO

trainer = COCO(
    generator,
    discriminator,
    discr_dim = 1024,            # the embedding dimension of the discriminator
    discr_layer = 'norm',        # the layer name in the discriminator whose output is used to predict whether each token is original or replaced
    cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
    mask_token_id = 2,           # the token id reserved for masking
    pad_token_id = 0,            # the token id for padding
    mask_prob = 0.15,            # masking probability for masked language modeling
    mask_ignore_token_ids = [],  # ids of tokens to ignore when masking, e.g. [CLS] and [SEP]
    cl_weight = 1.,              # weight for the contrastive learning loss
    disc_weight = 1.,            # weight for the corrective learning loss
    gen_weight = 1.              # weight for the MLM loss
)

# (4) train
data = torch.randint(0, 20000, (1, 1024))   # dummy token ids of shape (batch, seq_len)
loss = trainer(data)
loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, f'./pretrained-model.pt')
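
# (illustrative follow-up, not part of the original example) the saved module can
# later be reloaded for downstream fine-tuning; weights_only = False is needed on
# recent PyTorch versions because the whole module object was pickled above
pretrained_discriminator = torch.load('./pretrained-model.pt', weights_only = False)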
```

## Citations
```bibtex
@misc{meng2021cocolm,
title = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining},
author = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
year = {2021},
eprint = {2102.08473},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```