Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/devsisters/pointer-network-tensorflow
TensorFlow implementation of "Pointer Networks"
deep-learning pointer-networks tensorflow
- Host: GitHub
- URL: https://github.com/devsisters/pointer-network-tensorflow
- Owner: devsisters
- License: mit
- Created: 2017-01-17T02:53:52.000Z (almost 8 years ago)
- Default Branch: master
- Last Pushed: 2017-03-01T02:12:28.000Z (almost 8 years ago)
- Last Synced: 2024-12-30T12:03:43.096Z (12 days ago)
- Topics: deep-learning, pointer-networks, tensorflow
- Language: Python
- Homepage:
- Size: 740 KB
- Stars: 471
- Watchers: 24
- Forks: 138
- Open Issues: 13
Metadata Files:
- Readme: README.md
- License: LICENSE
# Pointer Networks in TensorFlow
TensorFlow implementation of [Pointer Networks](https://arxiv.org/abs/1506.03134).
![model](./assets/model.png)
*Supports multithreaded data pipelines to reduce I/O latency.*
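
To illustrate the general idea behind a multithreaded data pipeline, here is a minimal sketch that uses only the Python standard library. It is not the code from this repository: worker threads fill a bounded queue while the consumer drains it, so that loading can overlap with training. The names `load_example`, `start_producers`, and `iter_examples` are hypothetical, and the loader returns placeholder data.

```python
import threading

try:
    import queue  # Python 3
except ImportError:
    import Queue as queue  # Python 2.7

_DONE = object()  # sentinel marking the end of one worker's stream


def load_example(index):
    """Hypothetical loader; a real one would read or generate a TSP instance."""
    return {"id": index, "points": [(index, index)]}


def start_producers(indices, num_threads=2, capacity=8):
    """Start worker threads that push loaded examples into a bounded queue."""
    out = queue.Queue(maxsize=capacity)
    indices = list(indices)

    def worker(shard):
        for i in shard:
            out.put(load_example(i))  # blocks while the queue is full
        out.put(_DONE)

    for t in range(num_threads):
        # Each worker handles every num_threads-th index, starting at t.
        thread = threading.Thread(target=worker, args=(indices[t::num_threads],))
        thread.daemon = True
        thread.start()
    return out, num_threads


def iter_examples(out, num_threads):
    """Yield examples until every worker has signalled completion."""
    finished = 0
    while finished < num_threads:
        item = out.get()
        if item is _DONE:
            finished += 1
        else:
            yield item


if __name__ == "__main__":
    q, n = start_producers(range(10))
    print(sorted(example["id"] for example in iter_examples(q, n)))
```

The bounded queue is the key design choice in this sketch: it caps memory use and makes producers wait when the consumer falls behind. Examples arrive in no guaranteed order, since the workers run concurrently.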
## Requirements
- Python 2.7
- [requests](http://python-requests.org/)
- [tqdm](https://github.com/tqdm/tqdm)
- [TensorFlow 0.12.1](https://github.com/tensorflow/tensorflow/tree/r0.12)
## Usage
To train a model on the TSP task:

    $ python main.py --task=tsp --max_data_length=20 --hidden_dim=512 # download the dataset used in the paper
    $ python main.py --task=tsp --max_data_length=10 --hidden_dim=128 # generate the dataset locally

To train with the default settings and monitor progress in TensorBoard:

    $ python main.py
    $ tensorboard --logdir=logs --host=0.0.0.0

To test a model:

    $ python main.py --is_train=False
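
The commands in this section pass options such as `--task`, `--max_data_length`, `--hidden_dim`, and `--is_train` on the command line. As a rough sketch of how flags of this shape can be parsed (assuming `argparse`; the repository's actual configuration code may differ), note that `--is_train=False` needs an explicit string-to-boolean converter, because `bool("False")` evaluates to `True` in Python. The helper `str2bool` and all default values below are illustrative assumptions.

```python
import argparse


def str2bool(value):
    """Convert strings such as 'True', 'false', or '1' to a bool (hypothetical helper)."""
    return value.lower() in ("true", "1", "yes")


def build_parser():
    parser = argparse.ArgumentParser(description="Flag-parsing sketch")
    # Flag names are taken from the usage commands; the defaults are assumptions.
    parser.add_argument("--task", type=str, default="tsp")
    parser.add_argument("--max_data_length", type=int, default=10)
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser.add_argument("--is_train", type=str2bool, default=True)
    return parser


if __name__ == "__main__":
    config = build_parser().parse_args(["--is_train=False", "--hidden_dim=512"])
    print(config)
```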
## Results
Train/Test loss of `max_data_length=10` after 24,000 steps:

    $ python main.py --task=tsp --max_data_length=10 --hidden_dim=128
    ...
    2017-01-18 17:17:48,054:INFO::
    2017-01-18 17:17:48,054:INFO::test loss: 0.00149279157631
    2017-01-18 17:17:48,055:INFO::test pred: [1 6 2 5 4 3 0 0 0 0 0]
    2017-01-18 17:17:48,057:INFO::test true: [1 6 2 5 4 3 0 0 0 0 0] (True)
    2017-01-18 17:17:48,059:INFO::test pred: [1 3 8 4 7 5 2 6 0 0 0]
    2017-01-18 17:17:48,059:INFO::test true: [1 3 8 4 7 5 2 6 0 0 0] (True)
    2017-01-18 17:17:48,058:INFO::test pred: [ 1 8 3 7 9 5 6 4 2 10 0]
    2017-01-18 17:17:48,058:INFO::test true: [ 1 8 3 7 9 5 6 4 2 10 0] (True)

![model](./assets/max_data_length=10_step=24000.png)
Train/Test loss of `max_data_length=20` after 16,000 steps:

    $ python main.py --task=tsp --max_data_length=20 --hidden_dim=512
    ...
    2017-01-18 17:30:44,738:INFO::
    2017-01-18 17:30:44,739:INFO::test loss: 0.0501566007733
    2017-01-18 17:30:44,739:INFO::test pred: [ 1 3 16 18 6 2 7 4 5 15 10 8 13 11 17 20 19 14 9 12 0]
    2017-01-18 17:30:44,740:INFO::test true: [ 1 3 16 18 6 2 7 4 5 15 10 8 13 11 17 20 19 14 9 12 0] (True)
    2017-01-18 17:30:44,741:INFO::test pred: [ 1 4 12 8 2 17 20 3 18 14 16 11 13 15 7 10 9 6 19 5 0]
    2017-01-18 17:30:44,741:INFO::test true: [ 1 4 12 8 2 17 20 3 18 14 16 11 13 15 7 10 9 6 19 5 0] (True)
    2017-01-18 17:30:44,742:INFO::test pred: [ 1 8 16 20 19 14 18 2 9 12 15 11 3 13 5 6 7 4 17 10 0]
    2017-01-18 17:30:44,742:INFO::test true: [ 1 8 16 20 19 14 18 2 9 12 15 11 3 13 5 6 7 4 17 10 0] (True)

![model](./assets/max_data_length=20_step=16000.png)
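
The `(True)` marker in the logs appears to indicate that the predicted index sequence equals the reference sequence. A minimal sketch of such a check follows, assuming that index `0` is padding and that cities are numbered from 1. The helpers `strip_padding`, `tours_match`, and `tour_length` are hypothetical and are not functions from this repository.

```python
import math


def strip_padding(sequence, pad=0):
    """Drop trailing padding indices (assumes `pad` never occurs inside a tour)."""
    end = len(sequence)
    while end > 0 and sequence[end - 1] == pad:
        end -= 1
    return list(sequence[:end])


def tours_match(pred, true, pad=0):
    """Exact match of two index sequences after removing trailing padding."""
    return strip_padding(pred, pad) == strip_padding(true, pad)


def tour_length(points, tour):
    """Length of the closed tour; `tour` is a list of 1-based indices into `points`."""
    total = 0.0
    # Pair each city with its successor, wrapping around to the start.
    for a, b in zip(tour, tour[1:] + tour[:1]):
        (x1, y1), (x2, y2) = points[a - 1], points[b - 1]
        total += math.hypot(x2 - x1, y2 - y1)
    return total


if __name__ == "__main__":
    pred = [1, 6, 2, 5, 4, 3, 0, 0, 0, 0, 0]  # first prediction in the 10-city log
    true = [1, 6, 2, 5, 4, 3, 0, 0, 0, 0, 0]
    print(tours_match(pred, true))
    # Unit square visited in order, so the perimeter is 4.
    print(tour_length([(0, 0), (1, 0), (1, 1), (0, 1)], [1, 2, 3, 4]))
```

Exact match is a strict criterion: a tour traversed in the opposite direction has the same length but counts as a mismatch under this check, which is why comparing tour lengths can be a useful complement.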
## Author
Taehoon Kim / [@carpedm20](http://carpedm20.github.io)