Skip to content

Commit cee0e45

Browse files
committed
Fix bug in likelihood calculation
1 parent c021726 commit cee0e45

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

train.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from convcnp.cnp import RegressionCNP as CNP
1111
from convcnp.experiment import (
1212
report_loss,
13-
RunningAverage,
1413
generate_root,
1514
WorkingDirectory,
1615
save_checkpoint
@@ -21,35 +20,43 @@
2120

2221
def validate(data, model, report_freq=None):
2322
"""Compute the validation loss."""
24-
ravg = RunningAverage()
2523
model.eval()
24+
likelihoods = []
2625
with torch.no_grad():
2726
for step, task in enumerate(data):
27+
num_target = task['y_target'].shape[1]
2828
y_mean, y_std = \
2929
model(task['x_context'], task['y_context'], task['x_target'])
3030
obj = \
31-
-gaussian_logpdf(task['y_target'], y_mean, y_std,
31+
gaussian_logpdf(task['y_target'], y_mean, y_std,
3232
'batched_mean')
33-
ravg.update(obj.item() / data.batch_size, data.batch_size)
33+
likelihoods.append(obj.item() / num_target)
3434
if report_freq:
35-
report_loss('Validation', ravg.avg, step, report_freq)
36-
return ravg.avg
35+
avg_ll = np.array(likelihoods).mean()
36+
report_loss('Validation', avg_ll, step, report_freq)
37+
avg_ll = np.array(likelihoods).mean()
38+
return avg_ll
3739

3840

3941
def train(data, model, opt, report_freq):
4042
"""Perform a training epoch."""
41-
ravg = RunningAverage()
4243
model.train()
44+
losses = []
4345
for step, task in enumerate(data):
4446
y_mean, y_std = model(task['x_context'], task['y_context'],
4547
task['x_target'])
4648
obj = -gaussian_logpdf(task['y_target'], y_mean, y_std, 'batched_mean')
49+
50+
# Optimization
4751
obj.backward()
4852
opt.step()
4953
opt.zero_grad()
50-
ravg.update(obj.item() / data.batch_size, data.batch_size)
51-
report_loss('Training', ravg.avg, step, report_freq)
52-
return ravg.avg
54+
55+
# Track training progress
56+
losses.append(obj.item())
57+
avg_loss = np.array(losses).mean()
58+
report_loss('Training', avg_loss, step, report_freq)
59+
return avg_loss
5360

5461

5562
# Parse arguments given to the script.
@@ -141,7 +148,7 @@ def train(data, model, opt, report_freq):
141148
weight_decay=args.weight_decay)
142149
if args.train:
143150
# Run the training loop, maintaining the best objective value.
144-
best_obj = np.inf
151+
best_obj = -np.inf
145152
for epoch in range(args.epochs):
146153
print('\nEpoch: {}/{}'.format(epoch + 1, args.epochs))
147154

@@ -155,7 +162,7 @@ def train(data, model, opt, report_freq):
155162

156163
# Update the best objective value and checkpoint the model.
157164
is_best = False
158-
if val_obj < best_obj:
165+
if val_obj > best_obj:
159166
best_obj = val_obj
160167
is_best = True
161168
save_checkpoint(wd,
@@ -172,7 +179,7 @@ def train(data, model, opt, report_freq):
172179

173180
# Finally, test model on ~2000 tasks.
174181
test_obj = validate(gen_test, model)
175-
print('Model averages a log-likelihood of %s on unseen tasks.' % -test_obj)
182+
print('Model averages a log-likelihood of %s on unseen tasks.' % test_obj)
176183
with open(wd.file('test_log_likelihood.txt'), 'w') as f:
177-
f.write(str(-test_obj))
184+
f.write(str(test_obj))
178185

0 commit comments

Comments
 (0)