From 9994002667d2f3bb2f2535286340bbce69f4bcb1 Mon Sep 17 00:00:00 2001 From: yindaheng98 Date: Sun, 12 Jan 2025 21:19:00 -0800 Subject: [PATCH] bug fixed for multiple GPUs --- rasterize_points.cu | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/rasterize_points.cu b/rasterize_points.cu index e625c19e..2a66f631 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -75,11 +75,10 @@ RasterizeGaussiansCUDA( torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); - torch::Device device(torch::kCUDA); - torch::TensorOptions options(torch::kByte); - torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); - torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); - torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); + torch::TensorOptions byte_opts = means3D.options().dtype(torch::kByte); + torch::Tensor geomBuffer = torch::empty({0}, byte_opts); + torch::Tensor binningBuffer = torch::empty({0}, byte_opts); + torch::Tensor imgBuffer = torch::empty({0}, byte_opts); std::function geomFunc = resizeFunctional(geomBuffer); std::function binningFunc = resizeFunctional(binningBuffer); std::function imgFunc = resizeFunctional(imgBuffer);