Commit 06a1930

Support different label type of tfrecords
1 parent: 2721b0a

File tree

  dense_classifier.py
  sparse_classifier.py
  util.py

3 files changed: +100, -78 lines


dense_classifier.py

Lines changed: 9 additions & 5 deletions

@@ -16,6 +16,11 @@
 import util
 import model
 
+logging.basicConfig(
+    format='%(asctime)s %(levelname)-8s %(message)s',
+    level=logging.INFO,
+    datefmt='%Y-%m-%d %H:%M:%S')
+
 
 def define_flags():
   """
@@ -89,14 +94,17 @@ def define_flags():
   ])
 
   # Print flags
+  FLAGS.mode
   parameter_value_map = {}
   for key in FLAGS.__flags.keys():
     parameter_value_map[key] = FLAGS.__flags[key].value
   pprint.PrettyPrinter().pprint(parameter_value_map)
-
   return FLAGS
 
 
+FLAGS = define_flags()
+
+
 def parse_tfrecords_function(example_proto):
   """
   Decode TFRecords for Dataset.
@@ -175,10 +183,6 @@ def inference(inputs, input_units, output_units, is_train=True):
       FLAGS)
 
 
-logging.basicConfig(level=logging.INFO)
-FLAGS = define_flags()
-
-
 def main():
   """
   Train the TensorFlow models.
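
Two things are worth noting in this diff. First, FLAGS = define_flags() moves from the bottom of the file to module load, so FLAGS exists before any function that reads it is called. Second, the bare FLAGS.mode statement is not dead code: in TF 1.x, flag values are parsed lazily, and touching one flag forces the command line to be parsed so that the FLAGS.__flags dump below it is populated. A minimal sketch of the idiom, assuming TF 1.x's tf.app.flags (the flag name matches the repo's mode flag; the description string is illustrative):

    import tensorflow as tf

    flags = tf.app.flags
    flags.DEFINE_string("mode", "train", "Run mode")  # description is illustrative
    FLAGS = flags.FLAGS

    # TF 1.x parses sys.argv lazily: accessing any flag attribute triggers
    # parsing, so touch one flag before dumping FLAGS.__flags.
    FLAGS.mode
    parameter_value_map = {}
    for key in FLAGS.__flags.keys():
      parameter_value_map[key] = FLAGS.__flags[key].value
    print(parameter_value_map)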

sparse_classifier.py

Lines changed: 36 additions & 16 deletions

@@ -14,9 +14,13 @@
     signature_constants, signature_def_utils, tag_constants, utils)
 
 import sparse_model
-import model
 import util
 
+logging.basicConfig(
+    format='%(asctime)s %(levelname)-8s %(message)s',
+    level=logging.INFO,
+    datefmt='%Y-%m-%d %H:%M:%S')
+
 
 def define_flags():
   """
@@ -34,6 +38,7 @@ def define_flags():
       "The glob pattern of train TFRecords files")
   flags.DEFINE_integer("feature_size", 124, "Number of feature size")
   flags.DEFINE_integer("label_size", 2, "Number of label size")
+  flags.DEFINE_string("label_type", "int", "The type of label")
   flags.DEFINE_float("learning_rate", 0.01, "The learning rate")
   flags.DEFINE_integer("epoch_number", 10, "Number of epochs to train")
   flags.DEFINE_integer("batch_size", 1024, "The batch size of training")
@@ -81,14 +86,17 @@ def define_flags():
   ])
 
   # Print flags
+  FLAGS.mode
   parameter_value_map = {}
   for key in FLAGS.__flags.keys():
     parameter_value_map[key] = FLAGS.__flags[key].value
   pprint.PrettyPrinter().pprint(parameter_value_map)
-
   return FLAGS
 
 
+FLAGS = define_flags()
+
+
 def parse_tfrecords_function(example_proto):
   """
   Decode TFRecords for Dataset.
@@ -100,15 +108,31 @@ def parse_tfrecords_function(example_proto):
     The op of features and labels
   """
 
-  features = {
-      "ids": tf.VarLenFeature(tf.int64),
-      "values": tf.VarLenFeature(tf.float32),
-      "label": tf.FixedLenFeature([], tf.int64, default_value=0)
-  }
+  if FLAGS.label_type == "int":
+    features = {
+        "ids": tf.VarLenFeature(tf.int64),
+        "values": tf.VarLenFeature(tf.float32),
+        "label": tf.FixedLenFeature([], tf.int64, default_value=0)
+    }
 
-  parsed_features = tf.parse_single_example(example_proto, features)
-  return parsed_features["label"], parsed_features["ids"], parsed_features[
-      "values"]
+    parsed_features = tf.parse_single_example(example_proto, features)
+    labels = parsed_features["label"]
+    ids = parsed_features["ids"]
+    values = parsed_features["values"]
+
+  elif FLAGS.label_type == "float":
+    features = {
+        "ids": tf.VarLenFeature(tf.int64),
+        "values": tf.VarLenFeature(tf.float32),
+        "label": tf.FixedLenFeature([], tf.float32, default_value=0)
+    }
+
+    parsed_features = tf.parse_single_example(example_proto, features)
+    labels = tf.cast(parsed_features["label"], tf.int32)
+    ids = parsed_features["ids"]
+    values = parsed_features["values"]
+
+  return labels, ids, values
 
 
 def inference(sparse_ids, sparse_values, is_train=True):
@@ -133,10 +157,6 @@ def inference(sparse_ids, sparse_values, is_train=True):
       is_train, FLAGS)
 
 
-logging.basicConfig(level=logging.INFO)
-FLAGS = define_flags()
-
-
 def main():
 
   if os.path.exists(FLAGS.checkpoint_path) == False:
@@ -170,8 +190,8 @@ def main():
   validation_filename_placeholder = tf.placeholder(tf.string, shape=[None])
   validation_dataset = tf.data.TFRecordDataset(validation_filename_placeholder)
   validation_dataset = validation_dataset.map(parse_tfrecords_function).repeat(
-      epoch_number).batch(FLAGS.validation_batch_size).shuffle(
-          buffer_size=validation_buffer_size)
+  ).batch(FLAGS.validation_batch_size).shuffle(
+      buffer_size=validation_buffer_size)
   validation_dataset_iterator = validation_dataset.make_initializable_iterator(
   )
   validation_labels, validation_ids, validation_values = validation_dataset_iterator.get_next(
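
parse_tfrecords_function now branches on the new label_type flag: int64 labels are read as before, while float32 labels are parsed with a float feature spec and cast to tf.int32, so downstream code still receives integer class ids. Separately, the validation pipeline drops epoch_number from repeat(), so the validation dataset now repeats indefinitely. For the float branch to match anything, the records must have been written with a float label feature in the first place. A minimal writer-side sketch, assuming the TF 1.x tf.train.Example API; the helper name and output filename are illustrative, not part of this commit:

    import tensorflow as tf

    def make_example(ids, values, label, label_type="int"):
      # Hypothetical helper: emits one sparse sample in the layout that
      # parse_tfrecords_function reads back ("ids", "values", "label").
      if label_type == "int":
        label_feature = tf.train.Feature(
            int64_list=tf.train.Int64List(value=[int(label)]))
      else:  # "float"
        label_feature = tf.train.Feature(
            float_list=tf.train.FloatList(value=[float(label)]))
      feature = {
          "ids": tf.train.Feature(int64_list=tf.train.Int64List(value=ids)),
          "values": tf.train.Feature(float_list=tf.train.FloatList(value=values)),
          "label": label_feature,
      }
      return tf.train.Example(features=tf.train.Features(feature=feature))

    # One record with a float label; train on it with --label_type=float.
    with tf.python_io.TFRecordWriter("train_float.tfrecords") as writer:
      example = make_example([3, 17, 42], [1.0, 0.5, 2.0], 1.0, label_type="float")
      writer.write(example.SerializeToString())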

util.py

Lines changed: 55 additions & 57 deletions

@@ -1,16 +1,14 @@
-
 from __future__ import absolute_import, division, print_function
 
 import logging
 import os
 import tensorflow as tf
 from tensorflow.python.saved_model import builder as saved_model_builder
-from tensorflow.python.saved_model import (
-    signature_constants, tag_constants)
+from tensorflow.python.saved_model import (signature_constants, tag_constants)
 
 
 def get_optimizer_by_name(optimizer_name, learning_rate):
-    """
+  """
   Get optimizer object by the optimizer name.
 
   Args:
@@ -21,30 +19,30 @@ def get_optimizer_by_name(optimizer_name, learning_rate):
     The optimizer object.
   """
 
-    logging.info("Use the optimizer: {}".format(optimizer_name))
-    if optimizer_name == "sgd":
-        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
-    elif optimizer_name == "adadelta":
-        optimizer = tf.train.AdadeltaOptimizer(learning_rate)
-    elif optimizer_name == "adagrad":
-        optimizer = tf.train.AdagradOptimizer(learning_rate)
-    elif optimizer_name == "adam":
-        optimizer = tf.train.AdamOptimizer(learning_rate)
-    elif optimizer_name == "ftrl":
-        optimizer = tf.train.FtrlOptimizer(learning_rate)
-    elif optimizer_name == "rmsprop":
-        optimizer = tf.train.RMSPropOptimizer(learning_rate)
-    else:
-        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
-    return optimizer
+  logging.info("Use the optimizer: {}".format(optimizer_name))
+  if optimizer_name == "sgd":
+    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+  elif optimizer_name == "adadelta":
+    optimizer = tf.train.AdadeltaOptimizer(learning_rate)
+  elif optimizer_name == "adagrad":
+    optimizer = tf.train.AdagradOptimizer(learning_rate)
+  elif optimizer_name == "adam":
+    optimizer = tf.train.AdamOptimizer(learning_rate)
+  elif optimizer_name == "ftrl":
+    optimizer = tf.train.FtrlOptimizer(learning_rate)
+  elif optimizer_name == "rmsprop":
+    optimizer = tf.train.RMSPropOptimizer(learning_rate)
+  else:
+    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+  return optimizer
 
 
 def save_model(model_path,
-    model_version,
-    sess,
-    signature_def_map,
-    is_save_graph=False):
-    """
+               model_version,
+               sess,
+               signature_def_map,
+               is_save_graph=False):
+  """
   Save the model in standard SavedModel format.
 
   Args:
@@ -58,36 +56,36 @@ def save_model(model_path,
     None
   """
 
-    export_path = os.path.join(model_path, str(model_version))
-    if os.path.isdir(export_path) == True:
-        logging.error("The model exists in path: {}".format(export_path))
-        return
+  export_path = os.path.join(model_path, str(model_version))
+  if os.path.isdir(export_path) == True:
+    logging.error("The model exists in path: {}".format(export_path))
+    return
 
-    try:
-        # Save the SavedModel
-        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
-        builder = saved_model_builder.SavedModelBuilder(export_path)
-        builder.add_meta_graph_and_variables(
-            sess, [tag_constants.SERVING],
-            clear_devices=True,
-            signature_def_map=signature_def_map,
-            legacy_init_op=legacy_init_op)
-        logging.info("Save the model in: {}".format(export_path))
-        builder.save()
+  try:
+    # Save the SavedModel
+    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
+    builder = saved_model_builder.SavedModelBuilder(export_path)
+    builder.add_meta_graph_and_variables(
+        sess, [tag_constants.SERVING],
+        clear_devices=True,
+        signature_def_map=signature_def_map,
+        legacy_init_op=legacy_init_op)
+    logging.info("Save the model in: {}".format(export_path))
+    builder.save()
 
-        # Save the GraphDef
-        if is_save_graph == True:
-            graph_file_name = "graph.pb"
-            logging.info("Save the graph file in: {}".format(model_path))
-            tf.train.write_graph(
-                sess.graph_def, model_path, graph_file_name, as_text=False)
+    # Save the GraphDef
+    if is_save_graph == True:
+      graph_file_name = "graph.pb"
+      logging.info("Save the graph file in: {}".format(model_path))
+      tf.train.write_graph(
+          sess.graph_def, model_path, graph_file_name, as_text=False)
 
-    except Exception as e:
-        logging.error("Fail to export saved model, exception: {}".format(e))
+  except Exception as e:
+    logging.error("Fail to export saved model, exception: {}".format(e))
 
 
 def restore_from_checkpoint(sess, saver, checkpoint_file_path):
-    """
+  """
   Restore session from checkpoint files.
 
   Args:
@@ -98,11 +96,11 @@ def restore_from_checkpoint(sess, saver, checkpoint_file_path):
   Return:
     True if restore successfully and False if fail
   """
-    if checkpoint_file_path:
-        logging.info(
-            "Restore session from checkpoint: {}".format(checkpoint_file_path))
-        saver.restore(sess, checkpoint_file_path)
-        return True
-    else:
-        logging.error("Checkpoint not found: {}".format(checkpoint_file_path))
-        return False
+  if checkpoint_file_path:
+    logging.info(
+        "Restore session from checkpoint: {}".format(checkpoint_file_path))
+    saver.restore(sess, checkpoint_file_path)
+    return True
+  else:
+    logging.error("Checkpoint not found: {}".format(checkpoint_file_path))
+    return False
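
Unlike the other two files, this diff changes no behavior: every deleted line reappears with two-space instead of four-space indentation, a wrapped import is joined onto one line, and a leading blank line is dropped. The helpers are used as before; a minimal sketch of how they compose, assuming a trivially small graph and an illustrative checkpoint directory:

    import tensorflow as tf
    import util

    # A tiny graph so the Saver and the optimizer have variables to work on.
    weight = tf.Variable(0.0, name="weight")
    loss = tf.square(weight - 1.0)
    optimizer = util.get_optimizer_by_name("adam", learning_rate=0.01)
    train_op = optimizer.minimize(loss)

    saver = tf.train.Saver()
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # Resume from the newest checkpoint if one exists; returns False otherwise.
      latest_checkpoint = tf.train.latest_checkpoint("./checkpoint/")
      util.restore_from_checkpoint(sess, saver, latest_checkpoint)
      sess.run(train_op)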
