https://github.com/ChenghaoMou/pytorch-pQRNN
Implementation of pQRNN in PyTorch
- Host: GitHub
- URL: https://github.com/ChenghaoMou/pytorch-pQRNN
- Owner: ChenghaoMou
- License: mit
- Archived: true
- Created: 2020-10-18T19:16:56.000Z (about 4 years ago)
- Default Branch: main
- Last Pushed: 2021-10-10T21:34:41.000Z (about 3 years ago)
- Last Synced: 2024-10-07T18:48:08.762Z (2 months ago)
- Topics: nlp, pqrnn, pytorch, text-classification
- Language: Python
- Homepage:
- Size: 264 KB
- Stars: 46
- Watchers: 3
- Forks: 9
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
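- StarryDivineSky - ChenghaoMou/pytorch-pQRNN - an RNN encoder for fast parallel processing. The pQRNN model shows that this new architecture can nearly reach BERT-level performance while using only 1/300 of the parameters and only supervised data. (Other_NLP natural language processing / Other_text generation, text dialogue)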
README
![banner](./banner.png)
![PyPI](https://img.shields.io/pypi/v/pytorch-pqrnn?style=plastic) ![Maintenance](https://img.shields.io/maintenance/yes/2021?style=plastic) ![PyPI - License](https://img.shields.io/pypi/l/pytorch-pqrnn?style=plastic) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.4661601.svg)](https://doi.org/10.5281/zenodo.4661601)
## Installation
```bash
# install with pypi
pip install pytorch-pqrnn
# or install locally with poetry
poetry install
```

## Environment
Because of [this issue](https://github.com/salesforce/pytorch-qrnn/issues/29), `pytorch-qrnn` is no longer compatible with recent PyTorch releases, and it is not actively maintained. If you want to use a QRNN layer in this model, you have to install `pytorch-qrnn` with `torch <= 1.4` first.
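One way to guard against this at runtime is to choose the recurrent layer from the installed PyTorch version. The helper below is a minimal sketch and not part of this package: `parse_version` and `pick_rnn_type` are hypothetical names, the `1.4` cutoff is taken from the note above, and falling back to `GRU` is only an assumption based on it being the CLI default.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int]:
    """Return (major, minor) from a version string such as '1.4.0+cu101'."""
    # Drop any local build suffix ('+cu101') before splitting on dots.
    core = version.split("+")[0]
    major, minor = core.split(".")[:2]
    return int(major), int(minor)


def pick_rnn_type(torch_version: str, preferred: str = "QRNN") -> str:
    """Fall back to GRU when QRNN is requested on torch newer than 1.4."""
    if preferred == "QRNN" and parse_version(torch_version) > (1, 4):
        return "GRU"  # assumed fallback; any non-QRNN rnn_type would do
    return preferred


# In a real script, pass torch.__version__ instead of a literal.
print(pick_rnn_type("1.9.0"))
print(pick_rnn_type("1.4.0"))
```

The returned string can then be passed as `rnn_type` when constructing the model.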
## Usage
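The snippet in this section feeds `DummyDataset` a list of records produced by `df[["text", "label"]].to_dict("records")`. As a rough illustration of that input shape, the sketch below builds the same kind of list without pandas. `to_records` and the sample rows are made up for this example, and the label check simply mirrors the `label2index` mapping used in the snippet; it is an assumption that the dataset expects string labels that appear in that mapping.

```python
from typing import Dict, List, Sequence, Tuple


def to_records(
    rows: Sequence[Tuple[str, str]], label2index: Dict[str, int]
) -> List[Dict[str, str]]:
    """Build [{'text': ..., 'label': ...}, ...] and reject unknown labels."""
    records = []
    for text, label in rows:
        if label not in label2index:
            raise ValueError(f"unknown label: {label!r}")
        records.append({"text": text, "label": label})
    return records


# Hypothetical sample rows; real data would come from your own DataFrame.
rows = [("great food", "pos"), ("slow service", "neg")]
records = to_records(rows, {"pos": 1, "neg": 0})
print(records[0])
```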
```python
# DummyDataset is assumed to be defined in dataset.py, next to collate_fn
from pytorch_pqrnn.dataset import DummyDataset, create_dataloaders
from pytorch_pqrnn.model import PQRNN

model = PQRNN(
    b=128,
    d=96,
    lr=1e-3,
    num_layers=2,
    dropout=0.5,
    output_size=5,
    rnn_type="GRU",
    multilabel=False,
    nhead=2,  # used when rnn_type == "Transformer"
)

# Or load the model from your checkpoint
# model = PQRNN.load_from_checkpoint(checkpoint_path="example.ckpt")

# Text data has to be pre-processed with DummyDataset.
# df is your own pandas DataFrame with "text" and "label" columns.
dataset = DummyDataset(
    df[["text", "label"]].to_dict("records"),
    has_label=True,
    feature_size=128 * 2,
    add_eos_tag=True,
    add_bos_tag=True,
    max_seq_len=512,
    label2index={"pos": 1, "neg": 0},
)

# Explicit train/val loop
# Add model.eval() when necessary
dataloader = create_dataloaders(dataset)
for batch in dataloader:
    # labels could be an empty tensor if has_label is False when creating the dataset.
    # To change what is included in a batch, feel free to change the collate_fn function
    # in dataset.py
    projections, lengths, labels = batch
    logits = model.forward(projections)

    # do your magic
```

## CLI Usage
```bash
Usage: run.py [OPTIONS]

Options:
  --task [yelp2|yelp5|toxic]      [default: yelp5]
  --b INTEGER                     [default: 128]
  --d INTEGER                     [default: 96]
  --num_layers INTEGER            [default: 2]
  --batch_size INTEGER            [default: 512]
  --dropout FLOAT                 [default: 0.5]
  --lr FLOAT                      [default: 0.001]
  --nhead INTEGER                 [default: 4]
  --rnn_type [LSTM|GRU|QRNN|Transformer]
                                  [default: GRU]
  --data_path TEXT
  --help                          Show this message and exit.
```

Datasets
- yelp2 (polarity): downloaded automatically with huggingface/datasets
- yelp5: the [json file](https://www.kaggle.com/luisfredgs/hahnn-for-document-classification?select=yelp_reviews.json) should be downloaded into `data_path`
- toxic: the [dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge) should be downloaded and unzipped into `data_path`

### Example: Yelp Polarity

    python -W ignore run.py --task yelp2 --b 128 --d 64 --num_layers 4 --data_path data/
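When several configurations need to be launched, the same command can be assembled programmatically. The following is a minimal sketch that assumes `run.py` is in the current directory and uses only the flags listed in the CLI help above; `build_run_command` is a hypothetical helper, not something this repository provides.

```python
import shlex
from typing import Dict, List, Union


def build_run_command(options: Dict[str, Union[str, int, float]]) -> List[str]:
    """Turn an options dict into an argv list for run.py."""
    argv = ["python", "-W", "ignore", "run.py"]
    for name, value in options.items():
        # Each option becomes a '--name value' pair, in insertion order.
        argv += [f"--{name}", str(value)]
    return argv


command = build_run_command(
    {"task": "yelp2", "b": 128, "d": 64, "num_layers": 4, "data_path": "data/"}
)
print(shlex.join(command))
# To actually start training, pass the list to subprocess.run(command, check=True).
```

Building the command as a list rather than a single string avoids shell quoting problems if `data_path` contains spaces.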
## Benchmarks (not optimized)
| Model                             | Model Size | Yelp Polarity (error rate) | Yelp-5 (accuracy) | Civil Comments (mean AUROC) | Command                                                          |
| --------------------------------- | ---------- | -------------------------- | ----------------- | --------------------------- | ---------------------------------------------------------------- |
| ~~PQRNN (this repo)~~<sup>0</sup> | ~~78K~~    | ~~6.3~~                    | ~~70.4~~          | ~~TODO~~                    | `--b 128 --d 64 --num_layers 4 --rnn_type QRNN`                  |
| PRNN (this repo)                  | 90K        | 5.5                        | **70.7**          | 95.57                       | `--b 128 --d 64 --num_layers 1 --rnn_type GRU`                   |
| PTransformer (this repo)          | 618K       | 10.8                       | 68                | 92.4                        | `--b 128 --d 64 --num_layers 1 --rnn_type Transformer --nhead 8` |
| PRADO<sup>1</sup>                 | 175K       |                            | 65.9              |                             |                                                                  |
| BERT                              | 335M       | **1.81**                   | 70.58             | **98.856**<sup>2</sup>      |                                                                  |

0. Not supported with `torch >= 1.7`
1. [Paper](https://www.aclweb.org/anthology/D19-1506.pdf)
2. Best Kaggle Submission

## Credits
- [original tensorflow source code for PRADO](https://github.com/tensorflow/models/tree/master/research/sequence_projection/prado)
- Powered by [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning) and [grid.ai](https://www.grid.ai/)
## Citation
```
@software{chenghao_mou_2021_4661601,
author = {Chenghao MOU},
title = {ChenghaoMou/pytorch-pQRNN: Add DOI},
month = apr,
year = 2021,
publisher = {Zenodo},
version = {0.0.3},
doi = {10.5281/zenodo.4661601},
url = {https://doi.org/10.5281/zenodo.4661601}
}
```