https://github.com/vveitch/causal-text-embeddings-tf2
Tensorflow 2 implementation of Causal-BERT
- Host: GitHub
- URL: https://github.com/vveitch/causal-text-embeddings-tf2
- Owner: vveitch
- Created: 2019-12-19T19:29:57.000Z (almost 6 years ago)
- Default Branch: master
- Last Pushed: 2023-11-05T06:14:53.000Z (almost 2 years ago)
- Last Synced: 2024-10-27T18:58:03.624Z (about 1 year ago)
- Language: Python
- Size: 4.86 MB
- Stars: 69
- Watchers: 3
- Forks: 19
- Open Issues: 3
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
- awesome-tensorflow-2 - Tensorflow 2 implementation of Causal-BERT
README
# Causal-Bert TF2
This is a reference Tensorflow 2.1 / Keras implementation of the "causal bert" method described in [Using Text Embeddings for Causal Inference](https://arxiv.org/abs/1905.12741).
This method provides a way to estimate causal effects when either
(1) a treatment and an outcome are both influenced by confounders, and information about the confounding is contained in a
text passage. For example, we consider estimating the effect of adding a theorem to a paper on whether or not the paper
is accepted at a computer science conference, adjusting for the paper's abstract (topic, writing quality, etc.); or
(2) a treatment's effect on an outcome is mediated by text. For example, we consider whether the score of a reddit post is
affected by publicly listing the gender of the author, adjusting for the text of the post.
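In both settings the estimator attaches treatment and outcome heads to a shared representation of the text. Here is a minimal sketch of that idea, not the repo's actual classes: the BERT encoder is replaced by a plain dense layer over a precomputed embedding so the snippet runs standalone, and all layer sizes are illustrative assumptions.
```python
# Sketch: a shared text representation feeds a propensity head g(x) = P(T=1 | text)
# and two conditional-outcome heads Q(0, x), Q(1, x). A dense layer stands in
# for the BERT encoder so this runs without any checkpoints.
import tensorflow as tf

def build_causal_heads(embedding_dim: int = 768) -> tf.keras.Model:
    z = tf.keras.Input(shape=(embedding_dim,), name="text_embedding")
    shared = tf.keras.layers.Dense(200, activation="relu")(z)
    g = tf.keras.layers.Dense(1, activation="sigmoid", name="g")(shared)   # P(T=1 | text)
    q0 = tf.keras.layers.Dense(1, name="q0")(shared)                       # E[Y | T=0, text]
    q1 = tf.keras.layers.Dense(1, name="q1")(shared)                       # E[Y | T=1, text]
    return tf.keras.Model(inputs=z, outputs=[g, q0, q1])

model = build_causal_heads()
fake_embeddings = tf.random.normal([8, 768])   # stand-in for BERT's pooled output
g_hat, q0_hat, q1_hat = model(fake_embeddings)
ate_hat = tf.reduce_mean(q1_hat - q0_hat)      # plug-in contrast of the outcome heads
print(float(ate_hat))
```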
This is a reference implementation to make it easier for others to use and build on the project. The official code,
including instructions to reproduce the experiments, is available [here](https://github.com/blei-lab/causal-text-embeddings) (in Tensorflow 1.13).
There is also a [reference implementation in pytorch.](https://github.com/rpryzant/causal-bert-pytorch)
All code in tf_official is taken from https://github.com/tensorflow/models/tree/master/official
(and is subject to their licensing requirements).
# Instructions
1. Download BERT-Base, Uncased pre-trained model following instructions at https://github.com/tensorflow/models/tree/master/official/nlp/bert
Extract to ../pre-trained/uncased_L-12_H-768_A-12
2. In src/, run:
```
python -m PeerRead.model.run_causal_bert \
--input_files=../dat/PeerRead/proc/arxiv-all.tf_record \
--bert_config_file=../pre-trained/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=../pre-trained/uncased_L-12_H-768_A-12/bert_model.ckpt \
--vocab_file=../pre-trained/uncased_L-12_H-768_A-12/vocab.txt \
--seed=0 \
--strategy_type=mirror \
--train_batch_size=32
```
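Here `--strategy_type=mirror` selects a mirrored (synchronous multi-GPU) distribution strategy. If you want to sanity-check the preprocessed input before training, you can count records in the TFRecord with generic tf.data calls; this snippet is an illustration and makes no assumption about the repo's feature schema:
```python
# Peek at the preprocessed PeerRead TFRecord: count the serialized features in
# the first few examples. Uses only generic Tensorflow APIs.
import tensorflow as tf

dataset = tf.data.TFRecordDataset("../dat/PeerRead/proc/arxiv-all.tf_record")
for i, raw in enumerate(dataset.take(3)):
    example = tf.train.Example.FromString(raw.numpy())
    print(f"record {i}: {len(example.features.feature)} features")
```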
# Notes
1. This reference implementation doesn't necessarily reproduce the paper's results; I haven't tuned the relative weighting of the unsupervised and supervised losses (see the sketch after these notes).
2. PeerRead data from: https://github.com/allenai/PeerRead
3. Model performance is usually significantly improved by doing unsupervised pre-training on your dataset;
see PeerRead/model/run_pretraining for how to do this.
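For reference, the weighting mentioned in note 1 amounts to a single scalar knob on the unsupervised term. A minimal illustration, where the helper name and default weight are hypothetical and not taken from the repo:
```python
import tensorflow as tf

def total_causal_bert_loss(treatment_loss: tf.Tensor,
                           outcome_loss: tf.Tensor,
                           masked_lm_loss: tf.Tensor,
                           unsup_weight: float = 1.0) -> tf.Tensor:
    # Supervised causal losses plus a weighted unsupervised masked-LM term;
    # unsup_weight is the untuned knob referred to in note 1.
    return treatment_loss + outcome_loss + unsup_weight * masked_lm_loss

# Example with dummy scalar losses:
print(float(total_causal_bert_loss(tf.constant(0.3), tf.constant(0.5), tf.constant(2.0))))
```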