Skip to content

Commit ed1e6d9

Browse files
committed
minor changes
1 parent 86529a0 commit ed1e6d9

File tree

3 files changed

+63
-15
lines changed

3 files changed

+63
-15
lines changed

grasp_img_proc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def distorted_inputs(data_files, num_epochs, train=True, batch_size=None):
146146
return images, bboxes
147147

148148

149-
def inputs(data_files, batch_size, num_epochs, train=False):
149+
def inputs(data_files, num_epochs, train=False, batch_size=None):
150150
with tf.device('/cpu:0'):
151151
images, bboxes = batch_inputs(
152152
data_files, train, num_epochs, batch_size,

image_processing.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def batch_inputs(data_files, train, num_epochs, batch_size,
7878
shuffle=True,
7979
capacity=16)
8080
else:
81-
filename_queue = tf.train.string_input_producer(ata_files,
81+
filename_queue = tf.train.string_input_producer(data_files,
82+
num_epochs,
8283
shuffle=False,
8384
capacity=1)
8485

@@ -136,7 +137,7 @@ def distorted_inputs(data_files, num_epochs, train=True, batch_size=None):
136137
num_readers=FLAGS.num_readers)
137138
return images, labels
138139

139-
def inputs(data_files, batch_size, train=False, num_epochs=None):
140+
def inputs(data_files, num_epochs=None, train=False, batch_size=None):
140141
with tf.device('/cpu:0'):
141142
images, labels = batch_inputs(
142143
data_files, train, num_epochs, batch_size,

imagenet_classifier.py

+59-12
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@
1313
TRAIN_FILE = '/root/imagenet-data/train-00001-of-01024'
1414
VALIDATION_FILE = '/root/imagenet-data/validation-00004-of-00128'
1515

def data_files(dataset):
    """Return the list of TFRecord shard paths for the given dataset split.

    Args:
        dataset: split name (e.g. 'train' or 'validation'), used as the
            filename prefix under FLAGS.data_dir.

    Returns:
        List of matching shard file paths (may be empty if none match).
    """
    # Shards follow the ImageNet naming scheme '<split>-00001-of-01024'.
    tf_record_pattern = os.path.join(FLAGS.data_dir, '%s-*' % dataset)
    # Renamed from 'data_files' to avoid shadowing this function's own name.
    matched_files = tf.gfile.Glob(tf_record_pattern)
    return matched_files
2020

2121
def run_training():
22-
#tf.reset_default_graph()
23-
data_files_ = TRAIN_FILE
22+
#data_files_ = TRAIN_FILE
2423
#data_files_ = VALIDATION_FILE
25-
#data_files_ = data_files()
24+
data_files_ = data_files(FLAGS.train_or_validation)
2625
images, labels = image_processing.distorted_inputs(
27-
[data_files_], FLAGS.num_epochs, batch_size=FLAGS.batch_size)
26+
data_files_, FLAGS.num_epochs, batch_size=FLAGS.batch_size)
27+
2828
labels = tf.one_hot(labels, 1000)
2929
logits = inference(images)
3030
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
@@ -57,12 +57,9 @@ def run_training():
5757
_, loss_value, pred, acc = sess.run(
5858
[train_op, loss, correct_pred, accuracy])
5959
duration = time.time() - start_batch
60-
if step % 10 == 0:
60+
if step % 100 == 0:
6161
print('Step %d | loss = %.2f | accuracy = %.2f (%.3f sec/batch)')%(
6262
step, loss_value, acc, duration)
63-
if step % 500 == 0:
64-
summary = sess.run(merged_summary_op)
65-
summary_writer.add_summary(summary, step*FLAGS.batch_size)
6663
if step % 5000 == 0:
6764
saver.save(sess, FLAGS.model_path)
6865

@@ -75,8 +72,58 @@ def run_training():
7572
coord.join(threads)
7673
sess.close()
7774

def evaluation():
    """Evaluate the saved classifier on the split selected by flags.

    Builds the (non-distorted) input pipeline and inference graph, restores
    the trained weights from FLAGS.model_path by variable name, then streams
    batches until the input queue is exhausted (FLAGS.num_epochs reached),
    printing per-batch accuracy. Side effects only; returns None.
    """
    data_files_ = data_files(FLAGS.train_or_validation)
    images, labels = image_processing.inputs(
        data_files_, FLAGS.num_epochs, batch_size=FLAGS.batch_size)

    # Labels arrive as sparse class ids; inference emits 1000-way logits.
    labels = tf.one_hot(labels, 1000)
    logits = inference(images)
    # Use tf.argmax for both operands (tf.arg_max is the deprecated alias
    # of the same op; mixing the two spellings on one line was confusing).
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Restore only the model weights by name; queue/epoch-counter locals
    # keep their fresh initialization.
    weight_names = ['w1', 'b1', 'w2', 'b2', 'w3', 'b3', 'w4', 'b4',
                    'w5', 'b5', 'w_fc1', 'b_fc1', 'w_fc2', 'b_fc2',
                    'w_output', 'b_output']
    restore_map = {}
    for name in weight_names:
        restore_map[name] = [
            v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if v.name == name + ':0'][0]
    saver = tf.train.Saver(restore_map)
    saver.restore(sess, FLAGS.model_path)

    try:
        step = 0
        start_time = time.time()
        while not coord.should_stop():
            start_batch = time.time()
            acc = sess.run(accuracy)
            duration = time.time() - start_batch
            # Format BEFORE printing: the old `print('...') % (...)` form
            # only parsed as intended under the Python 2 print statement
            # and raises `TypeError: unsupported operand` on Python 3.
            print('Step %d | accuracy = %.2f (%.3f sec/batch)' % (
                step, acc, duration))
            step += 1
    except tf.errors.OutOfRangeError:
        # string_input_producer raises OutOfRange once num_epochs is done.
        print('Done evaluating for %d epochs, %d steps, %.1f min.' % (
            FLAGS.num_epochs, step, (time.time() - start_time) / 60))
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
117+
78118
def main(_):
    """Entry point: dispatch to training or evaluation based on the
    --train_or_validation flag ('train' trains; anything else evaluates)."""
    # Parenthesized print works identically under Python 2 (statement with a
    # single parenthesized expression) and Python 3, and matches the
    # print(...) style already used elsewhere in this file.
    if FLAGS.train_or_validation == 'train':
        print(' *** run training.')
        print(FLAGS.train_or_validation)
        run_training()
    else:
        print(' *** run validation.')
        print(FLAGS.train_or_validation)
        evaluation()
80127

81128
if __name__ == '__main__':
82129
parser = argparse.ArgumentParser()
@@ -117,7 +164,7 @@ def main(_):
117164
help='Variables for the model.'
118165
)
119166
parser.add_argument(
120-
'--train',
167+
'--train_or_validation',
121168
type=str,
122169
default='train',
123170
help='Train or evaluate the dataset'

0 commit comments

Comments
 (0)