Skip to content

Commit 34e4b1f

Browse files
committed
Add graph tracing
1 parent a2d2ce0 commit 34e4b1f

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

train.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,8 @@ def compute_loss(y_true, y_pred):
164164
opt.niter = 1
165165

166166
# Set up a log directory
167-
file_writer = tf.summary.create_file_writer(str(opt.log_dir/datetime.now().strftime('%Y%m%d-%H%M%S')))
167+
log_dir = str(opt.log_dir/datetime.now().strftime('%Y%m%d-%H%M%S'))
168+
file_writer = tf.summary.create_file_writer(log_dir)
168169

169170
# Set up a sample output directory
170171
opt.outf.mkdir(parents=True, exist_ok=True)
@@ -230,18 +231,26 @@ def distributed_train_step(dist_inputs):
230231
strategy.reduce(tf.distribute.ReduceOp.MEAN, D_G_z2, axis=None)
231232

232233

234+
# Trace graphs
235+
if not ckpt_manager.latest_checkpoint:
236+
tf.summary.trace_on(graph=True, profiler=True)
237+
data = next(iter(dataset_dist))
238+
distributed_train_step(data)
239+
with file_writer.as_default():
240+
tf.summary.trace_export('DCGAN_trace', step=0, profiler_outdir=log_dir)
241+
233242
for epoch in range(int(ckpt.epoch.numpy()), opt.niter + 1):
234243
for i, data in enumerate(dataset_dist):
235244
errD, errG, D_x, D_G_z1, D_G_z2 = distributed_train_step(data)
245+
236246
# Output training stats
237247
if i % 50 == 0:
238248
print(f'[{epoch}/{opt.niter}][{i}/{len(dataset)}]\t'
239249
f'Loss_D: {errD:.4f}\t'
240250
f'Loss_G: {errG:.4f}\t'
241251
f'D(x): {D_x:.4f}\t'
242252
f'D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')
243-
if opt.dry_run:
244-
break
253+
245254
# Log training stats
246255
ckpt.step.assign_add(1)
247256
step = int(ckpt.step.numpy())
@@ -251,8 +260,10 @@ def distributed_train_step(dist_inputs):
251260
tf.summary.scalar('D_x', D_x, step=step)
252261
tf.summary.scalar('D_G_z1', D_G_z1, step=step)
253262
tf.summary.scalar('D_G_z2', D_G_z2, step=step)
254-
if opt.dry_run:
255-
break
263+
264+
if opt.dry_run:
265+
break
266+
256267
# Check how the generator is doing by saving G's output on fixed_noise
257268
fake = netG(fixed_noise, training=False)
258269
# Scale it back to [0, 1]

0 commit comments

Comments (0)