Skip to content

Commit 9994002

Browse files
committed
bug fixed for multiple GPUs
1 parent 9c5c202 commit 9994002

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

rasterize_points.cu

+4-5
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,10 @@ RasterizeGaussiansCUDA(
7575

7676
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
7777

78-
torch::Device device(torch::kCUDA);
79-
torch::TensorOptions options(torch::kByte);
80-
torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
81-
torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
82-
torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
78+
torch::TensorOptions byte_opts = means3D.options().dtype(torch::kByte);
79+
torch::Tensor geomBuffer = torch::empty({0}, byte_opts);
80+
torch::Tensor binningBuffer = torch::empty({0}, byte_opts);
81+
torch::Tensor imgBuffer = torch::empty({0}, byte_opts);
8382
std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
8483
std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
8584
std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);

0 commit comments

Comments
 (0)