Skip to content

Commit deffe1c

Browse files
committed
initial commit
0 parents  commit deffe1c

11 files changed

+891
-0
lines changed

.gitignore

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Data
2+
data/hand
3+
data/gaze
4+
data/*
5+
samples
6+
outputs
7+
8+
# Log
9+
logs
10+
11+
# ETC
12+
paper.pdf
13+
14+
# Created by https://www.gitignore.io/api/python,vim
15+
16+
### Python ###
17+
# Byte-compiled / optimized / DLL files
18+
__pycache__/
19+
*.py[cod]
20+
*$py.class
21+
22+
# C extensions
23+
*.so
24+
25+
# Distribution / packaging
26+
.Python
27+
env/
28+
build/
29+
develop-eggs/
30+
dist/
31+
downloads/
32+
eggs/
33+
.eggs/
34+
lib/
35+
lib64/
36+
parts/
37+
sdist/
38+
var/
39+
wheels/
40+
*.egg-info/
41+
.installed.cfg
42+
*.egg
43+
44+
# PyInstaller
45+
# Usually these files are written by a python script from a template
46+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
47+
*.manifest
48+
*.spec
49+
50+
# Installer logs
51+
pip-log.txt
52+
pip-delete-this-directory.txt
53+
54+
# Unit test / coverage reports
55+
htmlcov/
56+
.tox/
57+
.coverage
58+
.coverage.*
59+
.cache
60+
nosetests.xml
61+
coverage.xml
62+
*,cover
63+
.hypothesis/
64+
65+
# Translations
66+
*.mo
67+
*.pot
68+
69+
# Django stuff:
70+
*.log
71+
local_settings.py
72+
73+
# Flask stuff:
74+
instance/
75+
.webassets-cache
76+
77+
# Scrapy stuff:
78+
.scrapy
79+
80+
# Sphinx documentation
81+
docs/_build/
82+
83+
# PyBuilder
84+
target/
85+
86+
# Jupyter Notebook
87+
.ipynb_checkpoints
88+
89+
# pyenv
90+
.python-version
91+
92+
# celery beat schedule file
93+
celerybeat-schedule
94+
95+
# dotenv
96+
.env
97+
98+
# virtualenv
99+
.venv/
100+
venv/
101+
ENV/
102+
103+
# Spyder project settings
104+
.spyderproject
105+
106+
# Rope project settings
107+
.ropeproject
108+
109+
110+
### Vim ###
111+
# swap
112+
[._]*.s[a-v][a-z]
113+
[._]*.sw[a-p]
114+
[._]s[a-v][a-z]
115+
[._]sw[a-p]
116+
# session
117+
Session.vim
118+
# temporary
119+
.netrwhist
120+
*~
121+
# auto-generated tag files
122+
tags
123+
124+
# End of https://www.gitignore.io/api/python,vim

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
The MIT License (MIT)
2+
3+
Copyright (c) 2016 Devsisters corp.
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Pointer Networks in Tensorflow
2+
3+
TensorFlow implementation of [Pointer Networks](https://arxiv.org/abs/1506.03134).
4+
5+
![model](./assets/model.png)
6+
7+
(in progress)
8+
9+
10+
## Requirements
11+
12+
- Python 2.7
13+
- [tqdm](httsp://github.com/tqdm/tqdm)
14+
- [TensorFlow 0.12.1](httsp://github.com/tensorflow/tensorflow/tree/r0.12)
15+
16+
17+
## Usage
18+
19+
To train a model:
20+
21+
$ python main.py --task=tsp20 --lr_start=0.001 --min_data_length=5 --max_data_length=20
22+
$ python main.py --task=tsp50 --lr_start=0.001 --min_data_length=5 --max_data_length=50
23+
$ python main.py --task=tsp100 --lr_start=0.0001 --min_data_length=5 --max_data_length=100
24+
25+
26+
To train a model:
27+
28+
$ python main.py
29+
$ tensorboard --logdir=logs --host=0.0.0.0
30+
31+
To test a model:
32+
33+
$ python main.py --is_train=False
34+
35+
## Results
36+
37+
(in progress)
38+
39+
40+
## Author
41+
42+
Taehoon Kim / [@carpedm20](http://carpedm20.github.io)

assets/model.png

401 KB
Loading

config.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#-*- coding: utf-8 -*-
2+
import argparse
3+
4+
def str2bool(v):
5+
return v.lower() in ('true', '1')
6+
7+
arg_lists = []
8+
parser = argparse.ArgumentParser()
9+
10+
def add_argument_group(name):
11+
arg = parser.add_argument_group(name)
12+
arg_lists.append(arg)
13+
return arg
14+
15+
# Network
16+
net_arg = add_argument_group('Network')
17+
net_arg.add_argument('--hidden_dim', type=int, default=128, help='')
18+
net_arg.add_argument('--num_layers', type=int, default=1, help='')
19+
net_arg.add_argument('--input_dim', type=int, default=2, help='')
20+
net_arg.add_argument('--max_enc_length', type=int, default=None, help='')
21+
net_arg.add_argument('--max_dec_length', type=int, default=None, help='')
22+
net_arg.add_argument('--init_min_val', type=float, default=-0.08, help='for uniform random initializer')
23+
net_arg.add_argument('--init_max_val', type=float, default=+0.08, help='for uniform random initializer')
24+
net_arg.add_argument('--num_glimpse', type=int, default=1, help='')
25+
net_arg.add_argument('--use_terminal_symbol', type=str2bool, default=True, help='Not implemented yet')
26+
27+
# Data
28+
data_arg = add_argument_group('Data')
29+
data_arg.add_argument('--task', type=str, default='TSP')
30+
data_arg.add_argument('--batch_size', type=int, default=128)
31+
data_arg.add_argument('--min_data_length', type=int, default=5)
32+
data_arg.add_argument('--max_data_length', type=int, default=10)
33+
data_arg.add_argument('--train_num', type=int, default=1000000)
34+
data_arg.add_argument('--valid_num', type=int, default=1000)
35+
data_arg.add_argument('--test_num', type=int, default=1000)
36+
37+
# Training / test parameters
38+
train_arg = add_argument_group('Training')
39+
train_arg.add_argument('--is_train', type=str2bool, default=True, help='')
40+
train_arg.add_argument('--optimizer', type=str, default='rmsprop', help='')
41+
train_arg.add_argument('--max_step', type=int, default=1000000, help='')
42+
train_arg.add_argument('--lr_start', type=float, default=0.001, help='')
43+
train_arg.add_argument('--lr_decay_step', type=int, default=5000, help='')
44+
train_arg.add_argument('--lr_decay_rate', type=float, default=0.96, help='')
45+
train_arg.add_argument('--max_grad_norm', type=float, default=1.0, help='')
46+
train_arg.add_argument('--checkpoint_secs', type=int, default=300, help='')
47+
48+
# Misc
49+
misc_arg = add_argument_group('Misc')
50+
misc_arg.add_argument('--log_step', type=int, default=20, help='')
51+
misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'], help='')
52+
misc_arg.add_argument('--log_dir', type=str, default='logs')
53+
misc_arg.add_argument('--data_dir', type=str, default='data')
54+
misc_arg.add_argument('--output_dir', type=str, default='outputs')
55+
misc_arg.add_argument('--load_path', type=str, default='')
56+
misc_arg.add_argument('--debug', type=str2bool, default=False)
57+
misc_arg.add_argument('--gpu_memory_fraction', type=float, default=1.0)
58+
misc_arg.add_argument('--random_seed', type=int, default=123, help='')
59+
60+
def get_config():
61+
config, unparsed = parser.parse_known_args()
62+
return config, unparsed

0 commit comments

Comments
 (0)