@@ -164,7 +164,8 @@ def compute_loss(y_true, y_pred):
     opt.niter = 1

 # Set up a log directory
-file_writer = tf.summary.create_file_writer(str(opt.log_dir / datetime.now().strftime('%Y%m%d-%H%M%S')))
+log_dir = str(opt.log_dir / datetime.now().strftime('%Y%m%d-%H%M%S'))
+file_writer = tf.summary.create_file_writer(log_dir)

 # Set up a sample output directory
 opt.outf.mkdir(parents=True, exist_ok=True)
@@ -230,18 +231,26 @@ def distributed_train_step(dist_inputs):
            strategy.reduce(tf.distribute.ReduceOp.MEAN, D_G_z2, axis=None)

+# Trace graphs
+if not ckpt_manager.latest_checkpoint:
+    tf.summary.trace_on(graph=True, profiler=True)
+    data = next(iter(dataset_dist))
+    distributed_train_step(data)
+    with file_writer.as_default():
+        tf.summary.trace_export('DCGAN_trace', step=0, profiler_outdir=log_dir)
+
 for epoch in range(int(ckpt.epoch.numpy()), opt.niter + 1):
     for i, data in enumerate(dataset_dist):
         errD, errG, D_x, D_G_z1, D_G_z2 = distributed_train_step(data)
+
         # Output training stats
         if i % 50 == 0:
             print(f'[{epoch}/{opt.niter}][{i}/{len(dataset)}]\t'
                   f'Loss_D: {errD:.4f}\t'
                   f'Loss_G: {errG:.4f}\t'
                   f'D(x): {D_x:.4f}\t'
                   f'D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')
-            if opt.dry_run:
-                break
+
         # Log training stats
         ckpt.step.assign_add(1)
         step = int(ckpt.step.numpy())
@@ -251,8 +260,10 @@ def distributed_train_step(dist_inputs):
             tf.summary.scalar('D_x', D_x, step=step)
             tf.summary.scalar('D_G_z1', D_G_z1, step=step)
             tf.summary.scalar('D_G_z2', D_G_z2, step=step)
-        if opt.dry_run:
-            break
+
+        if opt.dry_run:
+            break
+
     # Check how the generator is doing by saving G's output on fixed_noise
     fake = netG(fixed_noise, training=False)
     # Scale it back to [0, 1]
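For reference, the tracing block added in the second hunk follows the usual TensorFlow 2 pattern: enable tracing before the first call into a tf.function, run one step so the graph gets built, then export the trace into the same log directory the summary writer uses. Below is a minimal standalone sketch of that pattern, assuming TensorFlow 2.x; the step function and log path are illustrative stand-ins, not code from this PR.

import tensorflow as tf
from datetime import datetime
from pathlib import Path

# Timestamped log directory, mirroring the log_dir refactor in the first hunk.
log_dir = str(Path('logs') / datetime.now().strftime('%Y%m%d-%H%M%S'))
file_writer = tf.summary.create_file_writer(log_dir)

@tf.function
def train_step(x):  # hypothetical stand-in for distributed_train_step
    return x * x

# Enable tracing before the first traced call, run one step to build the
# graph, then export the trace so TensorBoard can render it.
tf.summary.trace_on(graph=True, profiler=True)
train_step(tf.constant(2.0))
with file_writer.as_default():
    tf.summary.trace_export('demo_trace', step=0, profiler_outdir=log_dir)

Guarding the trace with `if not ckpt_manager.latest_checkpoint:`, as the PR does, means the graph is only traced and exported on a fresh run rather than on every resume from a checkpoint.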