# Hard-coded example shards kept for quick single-shard debugging; the
# input pipeline normally resolves shards dynamically via data_files().
TRAIN_FILE = '/root/imagenet-data/train-00001-of-01024'
VALIDATION_FILE = '/root/imagenet-data/validation-00004-of-00128'
def data_files(dataset):
    """Return the list of TFRecord shard paths for `dataset`.

    Args:
        dataset: Shard-name prefix (e.g. 'train' or 'validation'); matched
            against files named '<dataset>-*' inside FLAGS.data_dir.

    Returns:
        List of matching file paths; empty if no shard matches.
    """
    # Shards follow the '<dataset>-xxxxx-of-xxxxx' naming convention.
    tf_record_pattern = os.path.join(FLAGS.data_dir, '%s-*' % dataset)
    # Renamed from `data_files` — the original local shadowed the
    # function's own name.
    # NOTE(review): tf.gfile.Glob is the TF 1.x API — confirm TF version.
    shards = tf.gfile.Glob(tf_record_pattern)
    return shards
20
20
21
21
def run_training ():
22
- #tf.reset_default_graph()
23
- data_files_ = TRAIN_FILE
22
+ #data_files_ = TRAIN_FILE
24
23
#data_files_ = VALIDATION_FILE
25
- # data_files_ = data_files()
24
+ data_files_ = data_files (FLAGS . train_or_validation )
26
25
images , labels = image_processing .distorted_inputs (
27
- [data_files_ ], FLAGS .num_epochs , batch_size = FLAGS .batch_size )
26
+ data_files_ , FLAGS .num_epochs , batch_size = FLAGS .batch_size )
27
+
28
28
labels = tf .one_hot (labels , 1000 )
29
29
logits = inference (images )
30
30
loss = tf .reduce_mean (tf .nn .softmax_cross_entropy_with_logits (
@@ -57,12 +57,9 @@ def run_training():
57
57
_ , loss_value , pred , acc = sess .run (
58
58
[train_op , loss , correct_pred , accuracy ])
59
59
duration = time .time () - start_batch
60
- if step % 10 == 0 :
60
+ if step % 100 == 0 :
61
61
print ('Step %d | loss = %.2f | accuracy = %.2f (%.3f sec/batch)' )% (
62
62
step , loss_value , acc , duration )
63
- if step % 500 == 0 :
64
- summary = sess .run (merged_summary_op )
65
- summary_writer .add_summary (summary , step * FLAGS .batch_size )
66
63
if step % 5000 == 0 :
67
64
saver .save (sess , FLAGS .model_path )
68
65
@@ -75,8 +72,58 @@ def run_training():
75
72
coord .join (threads )
76
73
sess .close ()
77
74
75
def evaluation():
    """Evaluate a saved model on the dataset selected by the CLI flag.

    Builds the (non-distorted) input pipeline for
    FLAGS.train_or_validation, restores the model weights from
    FLAGS.model_path by name, and streams batch accuracy until the input
    queue is exhausted (FLAGS.num_epochs epochs).
    """
    #data_files_ = TRAIN_FILE
    data_files_ = data_files(FLAGS.train_or_validation)
    images, labels = image_processing.inputs(
        data_files_, FLAGS.num_epochs, batch_size=FLAGS.batch_size)

    # 1000-way ImageNet classification.
    labels = tf.one_hot(labels, 1000)
    logits = inference(images)
    # Fix: use tf.argmax on both operands — the original mixed the
    # deprecated tf.arg_max with tf.argmax in the same comparison.
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Restore only the named model weights/biases; anything else in the
    # graph (e.g. optimizer slots) is deliberately excluded.
    var_names = ['w1', 'b1', 'w2', 'b2', 'w3', 'b3', 'w4', 'b4', 'w5', 'b5',
                 'w_fc1', 'b_fc1', 'w_fc2', 'b_fc2', 'w_output', 'b_output']
    var_map = {}
    for name in var_names:
        var_map[name] = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                         if v.name == name + ':0'][0]
    saver = tf.train.Saver(var_map)
    saver.restore(sess, FLAGS.model_path)

    try:
        step = 0
        start_time = time.time()
        while not coord.should_stop():
            start_batch = time.time()
            acc = sess.run(accuracy)
            duration = time.time() - start_batch
            # Format inside the print expression (the original relied on the
            # Python 2 `print (fmt) % args` parse, which breaks on Python 3).
            print('Step %d | accuracy = %.2f (%.3f sec/batch)' % (
                step, acc, duration))
            step += 1
    except tf.errors.OutOfRangeError:
        # Raised by the input queue when num_epochs is exhausted — this is
        # the normal termination path, not an error.
        print('Done evaluating for %d epochs, %d steps, %.1f min.' % (
            FLAGS.num_epochs, step, (time.time() - start_time) / 60))
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
117
+
78
118
def main (_ ):
79
- run_training ()
119
+ if FLAGS .train_or_validation == 'train' :
120
+ print ' *** run training.'
121
+ print FLAGS .train_or_validation
122
+ run_training ()
123
+ else :
124
+ print ' *** run validation.'
125
+ print FLAGS .train_or_validation
126
+ evaluation ()
80
127
81
128
if __name__ == '__main__' :
82
129
parser = argparse .ArgumentParser ()
@@ -117,7 +164,7 @@ def main(_):
117
164
help = 'Variables for the model.'
118
165
)
119
166
parser .add_argument (
120
- '--train ' ,
167
+ '--train_or_validation ' ,
121
168
type = str ,
122
169
default = 'train' ,
123
170
help = 'Train or evaluate the dataset'
0 commit comments