Description
I'm trying to fit a distribution with luz, but I get an error when I try to predict. I took the original code from https://torch.mlverse.org/docs/articles/distributions.html?q=distributions
Following that example, everything works when I fit the model in a manual training loop (I switched the optimizer to Adam to match my actual use case):
library(torch)
torch_manual_seed(1) # setting seed for reproducibility
x <- torch_randn(100, 1)
y <- 2*x + 1 + torch_randn(100, 1)
x_test <- torch_randn(50, 1)
y_test <- 2*x_test + 1
GaussianLinear <- nn_module(
  initialize = function() {
    # this linear predictor will estimate the mean of the normal distribution
    self$linear <- nn_linear(1, 1)
    # this parameter will hold the estimate of the variability
    self$scale <- nn_parameter(torch_ones(1))
  },
  forward = function(x) {
    # we estimate the mean
    loc <- self$linear(x)
    # return a normal distribution
    distr_normal(loc, self$scale)
  }
)
model <- GaussianLinear()
opt <- optim_adam(model$parameters, lr = 0.1)
for (i in 1:100) {
  opt$zero_grad()
  d <- model(x)
  loss <- torch_mean(-d$log_prob(y))
  loss$backward()
  opt$step()
  if (i %% 10 == 0)
    cat("iter: ", i, " loss: ", loss$item(), "\n")
}
silly <- as.numeric(model(x_test)$mean)
plot(silly, as.numeric(y_test))
When I embed the loss function in the module and use setup() %>% fit() instead, training also seems to work:
GaussianLinear2 <- nn_module(
  initialize = function() {
    # this linear predictor will estimate the mean of the normal distribution
    self$linear <- nn_linear(1, 1)
    # this parameter will hold the estimate of the variability
    self$scale <- nn_parameter(torch_ones(1))
  },
  forward = function(x) {
    # we estimate the mean
    loc <- self$linear(x)
    # return a normal distribution
    distr_normal(loc, self$scale)
  },
  loss = function(a, b) {
    d <- ctx$model(ctx$input)
    torch_mean(-d$log_prob(ctx$target))
  }
)
library(luz)
TorchModel_gauss <- GaussianLinear2 %>%
  setup(
    optimizer = optim_adam
  ) %>%
  fit(list(x, y), epochs = 100, verbose = TRUE)
However, when I try to predict, I get an error, even though the new data is already a torch_tensor:
silly <- predict(TorchModel_gauss, newdata = x_test, verbose = TRUE)
Error in (function (tensors, dim) : Expected a torch_tensor.
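The `(function (tensors, dim)` part of the message matches the arguments of `torch_cat()`. As far as I can tell, luz's `predict()` collects the `forward()` output of every batch and joins them with `torch_cat()`, which only accepts tensors. My `forward()` returns a distribution object, so that may be the cause. The snippet below is my own sanity check, not something from the luz docs. It tries the same concatenation on two distribution objects and on their means:

```r
library(torch)

# Two "batch outputs" of the kind GaussianLinear2$forward() returns:
# distribution objects, not tensors
d1 <- distr_normal(torch_zeros(3, 1), torch_ones(1))
d2 <- distr_normal(torch_zeros(2, 1), torch_ones(1))

# Concatenating the means (plain tensors) along the batch dimension works
means <- torch_cat(list(d1$mean, d2$mean), dim = 1)
print(means$shape)

# Concatenating the distribution objects themselves fails
msg <- tryCatch(
  torch_cat(list(d1, d2), dim = 1),
  error = function(e) conditionMessage(e)
)
print(msg)
```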
Any idea what I'm doing wrong? I also tried converting x and y into a dataset and wrapping it with dataloader(), but that didn't solve the issue.
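One restructuring that might avoid the problem is sketched below. It assumes that luz calls a module's `loss()` method as `loss(input, target)`, with `input` being the output of `forward()`. Under that assumption, `forward()` returns the plain mean tensor and the distribution is built inside `loss()`, so `predict()` only has to concatenate tensors. `GaussianLinear3` and `TorchModel_gauss3` are hypothetical names, and I haven't checked this against every luz version:

```r
library(torch)
library(luz)

torch_manual_seed(1)
x <- torch_randn(100, 1)
y <- 2 * x + 1 + torch_randn(100, 1)
x_test <- torch_randn(50, 1)

GaussianLinear3 <- nn_module(
  initialize = function() {
    # linear predictor for the mean of the normal distribution
    self$linear <- nn_linear(1, 1)
    # parameter holding the estimate of the standard deviation
    self$scale <- nn_parameter(torch_ones(1))
  },
  forward = function(x) {
    # return a plain tensor (the estimated mean) instead of a distribution,
    # so the per-batch outputs can be concatenated by predict()
    self$linear(x)
  },
  # assumption: luz passes the forward() output as `input`
  loss = function(input, target) {
    d <- distr_normal(input, self$scale)
    torch_mean(-d$log_prob(target))
  }
)

TorchModel_gauss3 <- GaussianLinear3 %>%
  setup(optimizer = optim_adam) %>%
  fit(list(x, y), epochs = 100, verbose = FALSE)

# predicted means for the test inputs, returned as a torch_tensor
pred_mean <- predict(TorchModel_gauss3, x_test)
```

If changing the module isn't an option, calling the trained network directly, as in the loop version, might also work: `TorchModel_gauss$model(x_test)$mean`. This assumes the fitted luz object exposes the module as `$model`.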