Error when using predict() - "Expected a torch_tensor" #151

Open
@andrewbcooper

I'm trying to fit a distribution using luz, but I get an error when I call predict(). I took the original code from https://torch.mlverse.org/docs/articles/distributions.html?q=distributions

Following the example, everything works well when I fit the model in a manual training loop (I switched the optimizer to Adam to match my actual use case):

library(torch)
torch_manual_seed(1) # setting seed for reproducibility
x <- torch_randn(100, 1)
y <- 2 * x + 1 + torch_randn(100, 1)

x_test <- torch_randn(50, 1)
y_test <- 2 * x_test + 1

GaussianLinear <- nn_module(
  initialize = function() {
    # this linear predictor will estimate the mean of the normal distribution
    self$linear <- nn_linear(1, 1)
    # this parameter will hold the estimate of the variability
    self$scale <- nn_parameter(torch_ones(1))
  },
  forward = function(x) {
    # we estimate the mean
    loc <- self$linear(x)
    # return a normal distribution
    distr_normal(loc, self$scale)
  }
)

model <- GaussianLinear()

opt <- optim_adam(model$parameters, lr = 0.1)

for (i in 1:100) {
  opt$zero_grad()
  d <- model(x)
  loss <- torch_mean(-d$log_prob(y))
  loss$backward()
  opt$step()
  if (i %% 10 == 0)
    cat("iter: ", i, " loss: ", loss$item(), "\n")
}

silly <- as.numeric(model(x_test)$mean)
plot(silly, as.numeric(y_test))

When I embed the loss function in the module and use setup() %>% fit() instead, training also seems to work:

GaussianLinear2 <- nn_module(
  initialize = function() {
    # this linear predictor will estimate the mean of the normal distribution
    self$linear <- nn_linear(1, 1)
    # this parameter will hold the estimate of the variability
    self$scale <- nn_parameter(torch_ones(1))
  },
  forward = function(x) {
    # we estimate the mean
    loc <- self$linear(x)
    # return a normal distribution
    distr_normal(loc, self$scale)
  },
  loss = function(a, b) {
    d <- ctx$model(ctx$input)
    torch_mean(-d$log_prob(ctx$target))
  }
)

TorchModel_gauss <- GaussianLinear2 %>%
  setup(
    optimizer = optim_adam
  ) %>%
  fit(list(x, y), epochs = 100, verbose = TRUE)

However, when I call predict(), I get an error, even though the new data is already a torch_tensor:

silly <- predict(TorchModel_gauss, newdata = x_test, verbose = TRUE)
Error in (function (tensors, dim) : Expected a torch_tensor.

Any idea what I'm doing wrong? I also tried converting x and y into a dataloader with dataloader(), but that didn't solve the issue.
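A possible workaround, sketched under assumptions and not verified against luz internals: the `(function (tensors, dim)` part of the error matches the signature of torch_cat(), so predict() appears to concatenate per-batch outputs as tensors, which would fail when forward() returns a distribution object instead. The sketch below sidesteps predict() and calls the trained module directly. It assumes the fitted luz object exposes the trained module as `$model`, and it uses a hypothetical `GaussianLinear3` whose `loss(input, target)` method assumes luz passes it the forward() output and the target.

```r
library(torch)
library(luz)

torch_manual_seed(1)
x <- torch_randn(100, 1)
y <- 2 * x + 1 + torch_randn(100, 1)
x_test <- torch_randn(50, 1)

# Hypothetical variant of GaussianLinear2 (the name is illustrative only)
GaussianLinear3 <- nn_module(
  initialize = function() {
    self$linear <- nn_linear(1, 1)
    self$scale <- nn_parameter(torch_ones(1))
  },
  forward = function(x) {
    # returns a distribution object, not a tensor
    distr_normal(self$linear(x), self$scale)
  },
  # Assumption: luz calls loss(pred, target), where pred is forward()'s output
  loss = function(input, target) {
    torch_mean(-input$log_prob(target))
  }
)

fitted <- fit(
  setup(GaussianLinear3, optimizer = optim_adam),
  list(x, y),
  epochs = 10,
  verbose = FALSE
)

# Assumption: the fitted luz object stores the trained nn_module in $model
net <- fitted$model
net$eval()

# Call the module directly instead of predict() and read the mean of the
# returned distribution without tracking gradients
pred_mean <- with_no_grad({
  as.numeric(net(x_test)$mean)
})
```

Here `pred_mean` holds one predicted mean per row of `x_test`, playing the same role as `silly` in the manual loop.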