https://github.com/rbroc/ctx_transformers

Training encoder for context-aware MLM and user encoding

## Training transformers with structured context
#### Short summary
This project focuses on training transformer encoders that incorporate higher-order context (i.e., other text from the same author and/or subreddit) into the encodings of a target text sequence.
The rationale is that:
a) knowledge of an individual's style and/or of the style associated with a given topic could produce better MLM predictions;
b) training on this context-aware version of MLM can simultaneously yield text representations at three levels: token-level, sequence-level, and (potentially separable) context-level representations;
c) aggregate representations of multiple sequences from the same context could be used as "user" encodings, e.g., to predict individual behavior or traits.

The idea behind context-aware MLM pretraining is the following.
We feed models a target sequence and a number of ‘context’ sequences (i.e., text from the same author, or from the same subreddit) as a single example, and train models on a variant of MLM where the MLM head is fed the combination (through sum, concatenation, or token-context attention) of token-level representations from the target sequence and an aggregate representation of the contexts.
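The three ways of feeding the MLM head described above can be sketched as follows. This is a minimal, hypothetical illustration in numpy, not the code in `reddit/models.py`. It assumes:

* Each context sequence has already been reduced to one vector, for example its [CLS] state.
* The "sum" and "concat" variants use a simple mean over contexts as the aggregate.
* The "attention" variant uses parameter-free, single-head scaled dot-product attention from each target token to the context vectors.

The function names (`aggregate_contexts`, `combine_target_and_context`) are illustrative only.

```python
import numpy as np


def aggregate_contexts(ctx: np.ndarray) -> np.ndarray:
    """Mean-pool per-context vectors: (n_ctx, d) -> (d,). Assumed aggregation."""
    return ctx.mean(axis=0)


def combine_target_and_context(tokens: np.ndarray, ctx: np.ndarray,
                               mode: str = "sum") -> np.ndarray:
    """Hypothetical input to the MLM head.

    tokens: (seq_len, d) token-level representations of the target sequence.
    ctx:    (n_ctx, d) one vector per context sequence (e.g. its [CLS] state).
    """
    if mode == "sum":
        # Add the same aggregate context vector to every target token.
        return tokens + aggregate_contexts(ctx)
    if mode == "concat":
        # Append the aggregate context vector to every token -> (seq_len, 2d).
        agg = np.broadcast_to(aggregate_contexts(ctx), tokens.shape)
        return np.concatenate([tokens, agg], axis=-1)
    if mode == "attention":
        # Each target token attends over the context vectors (scaled dot
        # product, no learned projections); result is added residually.
        scores = tokens @ ctx.T / np.sqrt(tokens.shape[-1])  # (seq_len, n_ctx)
        scores -= scores.max(axis=-1, keepdims=True)          # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)        # softmax over contexts
        return tokens + weights @ ctx
    raise ValueError(f"unknown mode: {mode!r}")
```

The "attention" variant replaces the fixed mean with a per-token weighted aggregate. Each masked position can therefore draw on different contexts. A trained model would normally add learned projections here and then pass the result through the MLM head's dense and vocabulary layers, which are omitted from this sketch.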

#### Models
We experiment with three DistilBERT-inspired architectures:

* a bi-encoder, where context and target are fed to two separate encoders;
* a ‘batch’ encoder, which is a single encoder with added context-aggregation and target-context combination layers;
* a hierarchical encoder, which applies attention across [CLS] tokens between standard transformer layers to integrate information across the contexts and the target sequence.

We evaluate the benefits of this training protocol in two ways. First, we compare MLM performance against both no-context and random-context MLM training. Second, we test the models on a triplet-loss author/subreddit discrimination task. We also experiment with selectively masking attention heads based on the type of context provided (author vs. subreddit), to simultaneously produce separable context representations.
They look (roughly) like this:

![img](misc/architectures.jpeg)
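As a rough illustration of the hierarchical encoder's cross-sequence step, the sketch below lets the [CLS] vectors of the target and its contexts attend to one another. It assumes:

* Hidden states are stacked as `(n_seq, seq_len, d)`, with [CLS] at position 0.
* Attention is a single head with no learned query/key/value projections.
* The update is residual.

This is a hypothetical simplification, not the implementation in `reddit/models.py`. The bi-encoder and ‘batch’ encoder are not shown.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def mix_cls_tokens(hidden: np.ndarray) -> np.ndarray:
    """Hypothetical cross-[CLS] attention step.

    hidden: (n_seq, seq_len, d), one row per sequence (target + contexts).
    Returns a copy in which each sequence's [CLS] vector (position 0) is
    updated with attention over all [CLS] vectors; other tokens are unchanged.
    """
    cls = hidden[:, 0, :]                                 # (n_seq, d)
    d = cls.shape[-1]
    weights = softmax(cls @ cls.T / np.sqrt(d), axis=-1)  # (n_seq, n_seq)
    out = hidden.copy()
    out[:, 0, :] = cls + weights @ cls                    # residual update
    return out
```

Between standard transformer layers, this step could be applied as `hidden = mix_cls_tokens(layer(hidden))` for each layer in a hypothetical `layers` list. The next layer's ordinary self-attention would then spread the mixed [CLS] information back to the tokens of each sequence.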

#### Rationale
The importance of this project is two-fold. First, tuning models to produce context-aware representations in this way may provide intrinsic advantages in NLP tasks without a substantial increase in model complexity. For example, in MLM, knowing the author may make it easier to reconstruct missing words accurately. Second, text-based representations of authors could be used to predict individual traits, following the intuition that linguistic behavior is systematically influenced by personality, experiences, etc.

#### Some notes on the repo
The Reddit dataset generated through the code in ```reddit/preprocessing``` is not uploaded for storage reasons. It is a ```TFRecords``` dataset that is streamed during training.
Models are defined in ```reddit/models.py``` and auxiliary layers in ```reddit/layers.py```. I am still experimenting with the architectures, so each of them has a few parameters to set.
Functions for (distributed) training and logging are in ```reddit/training.py``` and ```reddit/logging.py``` respectively.

#### Status
This project is still in progress.