https://github.com/thunlp-mt/transformer-dmb
Codes for our paper "Dynamic Multi-Branch Layers for On-Device Neural Machine Translation" in TASLP
https://github.com/thunlp-mt/transformer-dmb
Last synced: 3 months ago
JSON representation
Codes for our paper "Dynamic Multi-Branch Layers for On-Device Neural Machine Translation" in TASLP
- Host: GitHub
- URL: https://github.com/thunlp-mt/transformer-dmb
- Owner: THUNLP-MT
- License: bsd-3-clause
- Created: 2022-02-08T05:02:27.000Z (over 4 years ago)
- Default Branch: main
- Last Pushed: 2022-03-18T08:45:31.000Z (about 4 years ago)
- Last Synced: 2024-12-31T05:19:00.399Z (over 1 year ago)
- Language: Python
- Homepage:
- Size: 58.6 KB
- Stars: 2
- Watchers: 3
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# Dynamic Multi-Branch Layers for On-Device Neural Machine Translation
This is the codes for our work [Dynamic Multi-Branch Layers for On-Device Neural Machine Translation](https://ieeexplore.ieee.org/document/9729651). The implementation is on top of the open-source NMT toolkit [THUMT](https://github.com/THUNLP-MT/THUMT).
## Contents
* [Prerequisites](#prerequisites)
* [Training](#training)
* [Decoding](#decoding)
* [Profile](#profile)
* [License](#license)
* [Citation](#citation)
## Prerequisites
* Python >= 3.7
* tensorflow-cpu >= 2.0
* torch >= 1.7
* torchprofile
Please read the document of [THUMT](https://github.com/THUNLP-MT/THUMT/blob/master/docs/index.md) before using this Repository.
## Training
Training a Transformer-DMB model is nearly the same as a Transformer model, except for the following additional hyperparameters:
* `n`: the number of branches for each layer
* `shared_private`: enable shared-private reparameterization
```
export PYTHONPATH=:$PYTHONPATH
python thumt/bin/trainer.py \
--input \
--model transformer_dmb \
--vocabulary \
--parameters=hidden_size=128,filter_size=512,device_list=[0,1,2,3],\
update_cycle=4,train_steps=100000,\
shared_source_target_embedding=true,\
shared_embedding_and_softmax_weights=true,n=4,\
shared_private=true \
--hparam_set base
```
## Decoding
The following command decodes an input file with a pre-trained checkpoint.
```
export PYTHONPATH=:$PYTHONPATH
python thumt/bin/translator.py \
--input \
--model transformer_dmb \
--vocabulary \
--parameters=device_list=[0,1,2,3] \
--checkpoint \
--output
```
## Profile
Use the following command to count the number of MultAdds of a tiny Transformer-DMB model with 8 branches.
```
export PYTHONPATH=:$PYTHONPATH
python thumt/bin/trainer.py \
--input \
--model transformer_dmb \
--vocabulary \
--parameters=hidden_size=128,filter_size=512,n=8 \
--hparam_set base --profile --output /tmp
```
## License
Open source licensing is under the [BSD-3-Clause](https://opensource.org/licenses/BSD-3-Clause), which allows free use for research purposes.
## Citation
```
@article{tan2022dynamic,
author={Tan, Zhixing and Yang, Zeyuan and Zhang, Meng and Liu, Qun and Sun, Maosong and Liu, Yang},
journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
title={Dynamic Multi-Branch Layers for On-Device Neural Machine Translation},
year={2022},
volume={30},
pages={958--967},
doi={10.1109/TASLP.2022.3153257}}
```