Skip to content

Commit 489eab0

Browse files
committed
add a function of visualizing learned manifold for VAE
Additionally 1. correct minor bugs for ACGAN and infoGAN 2. add dimension of z to input-argument 3. save the result images under the folder with name of the simulation settings 4. change batchnorm function as it in tf.contrib.layers
1 parent 4c4f18a commit 489eab0

22 files changed

+230
-132
lines changed

ACGAN.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from utils import *
1010

1111
class ACGAN(object):
12-
def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, result_dir, log_dir):
12+
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
1313
self.sess = sess
1414
self.dataset_name = dataset_name
1515
self.checkpoint_dir = checkpoint_dir
@@ -26,7 +26,7 @@ def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, result
2626
self.output_height = 28
2727
self.output_width = 28
2828

29-
self.z_dim = 62 # dimension of noise-vector
29+
self.z_dim = z_dim # dimension of noise-vector
3030
self.y_dim = 10 # dimension of code-vector (label)
3131
self.c_dim = 1
3232

@@ -148,7 +148,7 @@ def build_model(self):
148148
t_vars = tf.trainable_variables()
149149
d_vars = [var for var in t_vars if 'd_' in var.name]
150150
g_vars = [var for var in t_vars if 'g_' in var.name]
151-
q_vars = [var for var in t_vars if 'd_' or 'c_' or 'g_' in var.name]
151+
q_vars = [var for var in t_vars if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)]
152152

153153
# optimizers
154154
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
@@ -246,8 +246,9 @@ def train(self):
246246
tot_num_samples = min(self.sample_num, self.batch_size)
247247
manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
248248
manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
249-
save_images(samples[:manifold_h*manifold_w,:,:,:], [manifold_h, manifold_w],
250-
'./'+self.result_dir+'/'+self.model_name+'_train_{:02d}_{:04d}.png'.format(epoch, idx))
249+
save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w], './' + check_folder(
250+
self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
251+
epoch, idx))
251252

252253
# After an epoch, start_batch_id is set to zero
253254
# non-zero value is only for the first epoch after loading pre-trained model
@@ -275,7 +276,7 @@ def visualize_results(self, epoch):
275276
samples = self.sess.run(self.fake_images, feed_dict={self.inputs:self.test_images, self.z: z_sample, self.y: y_one_hot})
276277

277278
save_images(samples[:image_frame_dim*image_frame_dim,:,:,:], [image_frame_dim, image_frame_dim],
278-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
279+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
279280

280281
""" specified condition, random noise """
281282
n_styles = 10 # must be less than or equal to self.batch_size
@@ -290,7 +291,7 @@ def visualize_results(self, epoch):
290291

291292
samples = self.sess.run(self.fake_images, feed_dict={self.inputs:self.test_images, self.z: z_sample, self.y: y_one_hot})
292293
save_images(samples[:image_frame_dim*image_frame_dim,:,:,:], [image_frame_dim, image_frame_dim],
293-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)
294+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)
294295

295296
samples = samples[si, :, :, :]
296297

@@ -306,13 +307,13 @@ def visualize_results(self, epoch):
306307
canvas[s * self.len_discrete_code + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]
307308

308309
save_images(canvas, [n_styles, self.len_discrete_code],
309-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')
310+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')
310311

311312
@property
312313
def model_dir(self):
313314
return "{}_{}_{}_{}".format(
314-
self.dataset_name, self.batch_size,
315-
self.output_height, self.output_width)
315+
self.model_name, self.dataset_name,
316+
self.batch_size, self.z_dim)
316317

317318
def save(self, checkpoint_dir, step):
318319
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

BEGAN.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from utils import *
1010

1111
class BEGAN(object):
12-
def __init__(self, sess, batch_size, epoch, dataset_name, checkpoint_dir, result_dir, log_dir):
12+
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
1313
self.sess = sess
1414
self.dataset_name = dataset_name
1515
self.checkpoint_dir = checkpoint_dir
@@ -26,7 +26,7 @@ def __init__(self, sess, batch_size, epoch, dataset_name, checkpoint_dir, result
2626
self.output_height = 28
2727
self.output_width = 28
2828

29-
self.z_dim = 62 # dimension of noise-vector
29+
self.z_dim = z_dim # dimension of noise-vector
3030
self.c_dim = 1
3131

3232
# BEGAN Parameter
@@ -208,7 +208,7 @@ def train(self):
208208
manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
209209
manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
210210
save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
211-
'./' + self.result_dir + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
211+
'./' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
212212
epoch, idx))
213213

214214
# After an epoch, start_batch_id is set to zero
@@ -235,13 +235,13 @@ def visualize_results(self, epoch):
235235
samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
236236

237237
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
238-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
238+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
239239

240240
@property
241241
def model_dir(self):
242242
return "{}_{}_{}_{}".format(
243-
self.dataset_name, self.batch_size,
244-
self.output_height, self.output_width)
243+
self.model_name, self.dataset_name,
244+
self.batch_size, self.z_dim)
245245

246246
def save(self, checkpoint_dir, step):
247247
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

CGAN.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from utils import *
1010

1111
class CGAN(object):
12-
def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, result_dir, log_dir):
12+
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
1313
self.sess = sess
1414
self.dataset_name = dataset_name
1515
self.checkpoint_dir = checkpoint_dir
@@ -26,7 +26,7 @@ def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, result
2626
self.output_height = 28
2727
self.output_width = 28
2828

29-
self.z_dim = 62 # dimension of noise-vector
29+
self.z_dim = z_dim # dimension of noise-vector
3030
self.y_dim = 10 # dimension of condition-vector (label)
3131
self.c_dim = 1
3232

@@ -210,7 +210,7 @@ def train(self):
210210
manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
211211
manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
212212
save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
213-
'./' + self.result_dir + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
213+
'./' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
214214
epoch, idx))
215215

216216
# After an epoch, start_batch_id is set to zero
@@ -240,7 +240,7 @@ def visualize_results(self, epoch):
240240
samples = self.sess.run(self.fake_images, feed_dict={self.inputs:self.test_images, self.z: z_sample, self.y: y_one_hot})
241241

242242
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
243-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
243+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
244244

245245
""" specified condition, random noise """
246246
n_styles = 10 # must be less than or equal to self.batch_size
@@ -255,7 +255,7 @@ def visualize_results(self, epoch):
255255

256256
samples = self.sess.run(self.fake_images, feed_dict={self.inputs:self.test_images, self.z: z_sample, self.y: y_one_hot})
257257
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
258-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)
258+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)
259259

260260
samples = samples[si, :, :, :]
261261

@@ -271,13 +271,13 @@ def visualize_results(self, epoch):
271271
canvas[s * self.y_dim + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]
272272

273273
save_images(canvas, [n_styles, self.y_dim],
274-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')
274+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')
275275

276276
@property
277277
def model_dir(self):
278278
return "{}_{}_{}_{}".format(
279-
self.dataset_name, self.batch_size,
280-
self.output_height, self.output_width)
279+
self.model_name, self.dataset_name,
280+
self.batch_size, self.z_dim)
281281

282282
def save(self, checkpoint_dir, step):
283283
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

CVAE.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import prior_factory as prior
1212

1313
class CVAE(object):
14-
def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, result_dir, log_dir):
14+
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
1515
self.sess = sess
1616
self.dataset_name = dataset_name
1717
self.checkpoint_dir = checkpoint_dir
@@ -28,7 +28,7 @@ def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, result
2828
self.output_height = 28
2929
self.output_width = 28
3030

31-
self.z_dim = 62 # dimension of noise-vector
31+
self.z_dim = z_dim # dimension of noise-vector
3232
self.y_dim = 10 # dimension of condition-vector (label)
3333
self.c_dim = 1
3434

@@ -186,7 +186,7 @@ def train(self):
186186
batch_labels = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size]
187187
batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)
188188

189-
# update D network
189+
# update autoencoder
190190
_, summary_str, loss, nll_loss, kl_loss = self.sess.run([self.optim, self.merged_summary_op, self.loss, self.neg_loglikelihood, self.KL_divergence],
191191
feed_dict={self.inputs: batch_images, self.y: batch_labels, self.z: batch_z})
192192
self.writer.add_summary(summary_str, counter)
@@ -204,7 +204,7 @@ def train(self):
204204
manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
205205
manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
206206
save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
207-
'./' + self.result_dir + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
207+
'./' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
208208
epoch, idx))
209209

210210
# After an epoch, start_batch_id is set to zero
@@ -234,7 +234,7 @@ def visualize_results(self, epoch):
234234
samples = self.sess.run(self.fake_images, feed_dict={self.inputs:self.test_images, self.z: z_sample, self.y: y_one_hot})
235235

236236
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
237-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
237+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
238238

239239
""" specified condition, random noise """
240240
n_styles = 10 # must be less than or equal to self.batch_size
@@ -249,7 +249,7 @@ def visualize_results(self, epoch):
249249

250250
samples = self.sess.run(self.fake_images, feed_dict={self.inputs:self.test_images, self.z: z_sample, self.y: y_one_hot})
251251
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
252-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)
252+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)
253253

254254
samples = samples[si, :, :, :]
255255

@@ -265,13 +265,13 @@ def visualize_results(self, epoch):
265265
canvas[s * self.y_dim + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]
266266

267267
save_images(canvas, [n_styles, self.y_dim],
268-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')
268+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')
269269

270270
@property
271271
def model_dir(self):
272272
return "{}_{}_{}_{}".format(
273-
self.dataset_name, self.batch_size,
274-
self.output_height, self.output_width)
273+
self.model_name, self.dataset_name,
274+
self.batch_size, self.z_dim)
275275

276276
def save(self, checkpoint_dir, step):
277277
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

DRAGAN.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from utils import *
1010

1111
class DRAGAN(object):
12-
def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, result_dir, log_dir):
12+
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
1313
self.sess = sess
1414
self.dataset_name = dataset_name
1515
self.checkpoint_dir = checkpoint_dir
@@ -26,7 +26,7 @@ def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, result
2626
self.output_height = 28
2727
self.output_width = 28
2828

29-
self.z_dim = 62 # dimension of noise-vector
29+
self.z_dim = z_dim # dimension of noise-vector
3030
self.c_dim = 1
3131

3232
# DRAGAN parameter
@@ -211,7 +211,7 @@ def train(self):
211211
manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
212212
manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
213213
save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
214-
'./' + self.result_dir + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
214+
'./' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
215215
epoch, idx))
216216

217217
# After an epoch, start_batch_id is set to zero
@@ -238,13 +238,13 @@ def visualize_results(self, epoch):
238238
samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
239239

240240
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
241-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
241+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
242242

243243
@property
244244
def model_dir(self):
245245
return "{}_{}_{}_{}".format(
246-
self.dataset_name, self.batch_size,
247-
self.output_height, self.output_width)
246+
self.model_name, self.dataset_name,
247+
self.batch_size, self.z_dim)
248248

249249
def save(self, checkpoint_dir, step):
250250
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

EBGAN.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from utils import *
1010

1111
class EBGAN(object):
12-
def __init__(self, sess, batch_size, epoch, dataset_name, checkpoint_dir, result_dir, log_dir):
12+
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
1313
self.sess = sess
1414
self.dataset_name = dataset_name
1515
self.checkpoint_dir = checkpoint_dir
@@ -26,7 +26,7 @@ def __init__(self, sess, batch_size, epoch, dataset_name, checkpoint_dir, result
2626
self.output_height = 28
2727
self.output_width = 28
2828

29-
self.z_dim = 62 # dimension of noise-vector
29+
self.z_dim = z_dim # dimension of noise-vector
3030
self.c_dim = 1
3131

3232
# EBGAN Parameter
@@ -208,7 +208,7 @@ def train(self):
208208
manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
209209
manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
210210
save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
211-
'./' + self.result_dir + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
211+
'./' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
212212
epoch, idx))
213213

214214
# After an epoch, start_batch_id is set to zero
@@ -235,13 +235,13 @@ def visualize_results(self, epoch):
235235
samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
236236

237237
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
238-
self.result_dir + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
238+
check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
239239

240240
@property
241241
def model_dir(self):
242242
return "{}_{}_{}_{}".format(
243-
self.dataset_name, self.batch_size,
244-
self.output_height, self.output_width)
243+
self.model_name, self.dataset_name,
244+
self.batch_size, self.z_dim)
245245

246246
def save(self, checkpoint_dir, step):
247247
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

0 commit comments

Comments
 (0)