10
10
from convcnp .cnp import RegressionCNP as CNP
11
11
from convcnp .experiment import (
12
12
report_loss ,
13
- RunningAverage ,
14
13
generate_root ,
15
14
WorkingDirectory ,
16
15
save_checkpoint
21
20
22
21
def validate(data, model, report_freq=None):
    """Compute the average validation log-likelihood per target point.

    Args:
        data: Iterable of task dicts with keys 'x_context', 'y_context',
            'x_target', and 'y_target'.
        model: Model mapping (x_context, y_context, x_target) to a
            predictive mean and standard deviation.
        report_freq (int, optional): If given, report the running average
            every `report_freq` steps via `report_loss`.

    Returns:
        float: Mean per-target-point log-likelihood over all tasks, or
        nan if `data` yields no tasks.
    """
    model.eval()
    likelihoods = []
    with torch.no_grad():
        for step, task in enumerate(data):
            # Number of target points; assumes y_target is laid out as
            # (batch, num_target, ...) — TODO confirm against the generator.
            num_target = task['y_target'].shape[1]
            y_mean, y_std = \
                model(task['x_context'], task['y_context'], task['x_target'])
            obj = gaussian_logpdf(task['y_target'], y_mean, y_std,
                                  'batched_mean')
            # Normalise by the number of target points so tasks of
            # different sizes contribute comparably to the average.
            likelihoods.append(obj.item() / num_target)
            if report_freq:
                report_loss('Validation', np.mean(likelihoods), step,
                            report_freq)
    # Guard the empty case explicitly: np.mean([]) would emit a
    # RuntimeWarning and silently return nan.
    if not likelihoods:
        return np.nan
    return np.mean(likelihoods)
37
39
38
40
39
41
def train(data, model, opt, report_freq):
    """Perform a training epoch.

    Args:
        data: Iterable of task dicts with keys 'x_context', 'y_context',
            'x_target', and 'y_target'.
        model: Model mapping (x_context, y_context, x_target) to a
            predictive mean and standard deviation.
        opt: Optimizer whose `step`/`zero_grad` drive the updates.
        report_freq (int): Frequency at which `report_loss` reports the
            running average training loss.

    Returns:
        float: Running average negative log-likelihood over the epoch,
        or nan if `data` yields no tasks.
    """
    model.train()
    losses = []
    # Bug fix: previously `avg_loss` was only bound inside the loop, so an
    # empty epoch raised NameError at the final return.
    avg_loss = np.nan
    for step, task in enumerate(data):
        y_mean, y_std = model(task['x_context'], task['y_context'],
                              task['x_target'])
        # Negative log-likelihood is the minimisation objective.
        obj = -gaussian_logpdf(task['y_target'], y_mean, y_std,
                               'batched_mean')

        # Optimization
        obj.backward()
        opt.step()
        opt.zero_grad()

        # Track training progress via the running average loss.
        losses.append(obj.item())
        avg_loss = np.mean(losses)
        report_loss('Training', avg_loss, step, report_freq)
    return avg_loss
53
60
54
61
55
62
# Parse arguments given to the script.
@@ -141,7 +148,7 @@ def train(data, model, opt, report_freq):
141
148
weight_decay = args .weight_decay )
142
149
if args .train :
143
150
# Run the training loop, maintaining the best objective value.
144
- best_obj = np .inf
151
+ best_obj = - np .inf
145
152
for epoch in range (args .epochs ):
146
153
print ('\n Epoch: {}/{}' .format (epoch + 1 , args .epochs ))
147
154
@@ -155,7 +162,7 @@ def train(data, model, opt, report_freq):
155
162
156
163
# Update the best objective value and checkpoint the model.
157
164
is_best = False
158
- if val_obj < best_obj :
165
+ if val_obj > best_obj :
159
166
best_obj = val_obj
160
167
is_best = True
161
168
save_checkpoint (wd ,
@@ -172,7 +179,7 @@ def train(data, model, opt, report_freq):
172
179
173
180
# Finally, test model on ~2000 tasks.
174
181
test_obj = validate (gen_test , model )
175
- print ('Model averages a log-likelihood of %s on unseen tasks.' % - test_obj )
182
+ print ('Model averages a log-likelihood of %s on unseen tasks.' % test_obj )
176
183
with open (wd .file ('test_log_likelihood.txt' ), 'w' ) as f :
177
- f .write (str (- test_obj ))
184
+ f .write (str (test_obj ))
178
185
0 commit comments