Skip to content

Commit 4a3f789

Browse files
committed
Add depth forward pass
1 parent df36a86 commit 4a3f789

File tree

7 files changed

+35
-18
lines changed

7 files changed

+35
-18
lines changed

cuda_rasterizer/forward.cu

+13-4
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,13 @@ renderCUDA(
266266
int W, int H,
267267
const float2* __restrict__ points_xy_image,
268268
const float* __restrict__ features,
269+
const float* __restrict__ depths,
269270
const float4* __restrict__ conic_opacity,
270271
float* __restrict__ final_T,
271272
uint32_t* __restrict__ n_contrib,
272273
const float* __restrict__ bg_color,
273-
float* __restrict__ out_color)
274+
float* __restrict__ out_color,
275+
float* __restrict__ out_depth)
274276
{
275277
// Identify current tile and associated min/max pixel range.
276278
auto block = cg::this_thread_block();
@@ -301,6 +303,7 @@ renderCUDA(
301303
uint32_t contributor = 0;
302304
uint32_t last_contributor = 0;
303305
float C[CHANNELS] = { 0 };
306+
float D = { 0 };
304307

305308
// Iterate over batches until all done or range is complete
306309
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
@@ -353,6 +356,7 @@ renderCUDA(
353356
// Eq. (3) from 3D Gaussian splatting paper.
354357
for (int ch = 0; ch < CHANNELS; ch++)
355358
C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
359+
D += depths[collected_id[j]] * alpha * T;
356360

357361
T = test_T;
358362

@@ -370,6 +374,7 @@ renderCUDA(
370374
n_contrib[pix_id] = last_contributor;
371375
for (int ch = 0; ch < CHANNELS; ch++)
372376
out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
377+
out_depth[pix_id] = D;
373378
}
374379
}
375380

@@ -380,23 +385,27 @@ void FORWARD::render(
380385
int W, int H,
381386
const float2* means2D,
382387
const float* colors,
388+
const float* depths,
383389
const float4* conic_opacity,
384390
float* final_T,
385391
uint32_t* n_contrib,
386392
const float* bg_color,
387-
float* out_color)
393+
float* out_color,
394+
float* out_depth)
388395
{
389396
renderCUDA<NUM_CHANNELS> << <grid, block >> > (
390397
ranges,
391398
point_list,
392399
W, H,
393400
means2D,
394401
colors,
402+
depths,
395403
conic_opacity,
396404
final_T,
397405
n_contrib,
398406
bg_color,
399-
out_color);
407+
out_color,
408+
out_depth);
400409
}
401410

402411
void FORWARD::preprocess(int P, int D, int M,
@@ -452,4 +461,4 @@ void FORWARD::preprocess(int P, int D, int M,
452461
tiles_touched,
453462
prefiltered
454463
);
455-
}
464+
}

cuda_rasterizer/forward.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ namespace FORWARD
5555
int W, int H,
5656
const float2* points_xy_image,
5757
const float* features,
58+
const float* depths,
5859
const float4* conic_opacity,
5960
float* final_T,
6061
uint32_t* n_contrib,
6162
const float* bg_color,
62-
float* out_color);
63+
float* out_color,
64+
float* out_depth);
6365
}
6466

6567

66-
#endif
68+
#endif

cuda_rasterizer/rasterizer.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ namespace CudaRasterizer
4949
const float tan_fovx, float tan_fovy,
5050
const bool prefiltered,
5151
float* out_color,
52+
float* out_depth,
5253
int* radii = nullptr,
5354
bool debug = false);
5455

@@ -85,4 +86,4 @@ namespace CudaRasterizer
8586
};
8687
};
8788

88-
#endif
89+
#endif

cuda_rasterizer/rasterizer_impl.cu

+5-2
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ int CudaRasterizer::Rasterizer::forward(
216216
const float tan_fovx, float tan_fovy,
217217
const bool prefiltered,
218218
float* out_color,
219+
float* out_depth,
219220
int* radii,
220221
bool debug)
221222
{
@@ -326,11 +327,13 @@ int CudaRasterizer::Rasterizer::forward(
326327
width, height,
327328
geomState.means2D,
328329
feature_ptr,
330+
geomState.depths,
329331
geomState.conic_opacity,
330332
imgState.accum_alpha,
331333
imgState.n_contrib,
332334
background,
333-
out_color), debug)
335+
out_color,
336+
out_depth), debug)
334337

335338
return num_rendered;
336339
}
@@ -431,4 +434,4 @@ void CudaRasterizer::Rasterizer::backward(
431434
dL_dsh,
432435
(glm::vec3*)dL_dscale,
433436
(glm::vec4*)dL_drot), debug)
434-
}
437+
}

diff_gaussian_rasterization/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -83,22 +83,22 @@ def forward(
8383
if raster_settings.debug:
8484
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
8585
try:
86-
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
86+
num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
8787
except Exception as ex:
8888
torch.save(cpu_args, "snapshot_fw.dump")
8989
print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
9090
raise ex
9191
else:
92-
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
92+
num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
9393

9494
# Keep relevant tensors for backward
9595
ctx.raster_settings = raster_settings
9696
ctx.num_rendered = num_rendered
9797
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
98-
return color, radii
98+
return color, radii, depth
9999

100100
@staticmethod
101-
def backward(ctx, grad_out_color, _):
101+
def backward(ctx, grad_out_color, grad_radii, grad_depth):
102102

103103
# Restore necessary values from context
104104
num_rendered = ctx.num_rendered

rasterize_points.cu

+5-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
3232
return lambda;
3333
}
3434

35-
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
35+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
3636
RasterizeGaussiansCUDA(
3737
const torch::Tensor& background,
3838
const torch::Tensor& means3D,
@@ -66,6 +66,7 @@ RasterizeGaussiansCUDA(
6666
auto float_opts = means3D.options().dtype(torch::kFloat32);
6767

6868
torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
69+
torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts);
6970
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
7071

7172
torch::Device device(torch::kCUDA);
@@ -108,10 +109,11 @@ RasterizeGaussiansCUDA(
108109
tan_fovy,
109110
prefiltered,
110111
out_color.contiguous().data<float>(),
112+
out_depth.contiguous().data<float>(),
111113
radii.contiguous().data<int>(),
112114
debug);
113115
}
114-
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
116+
return std::make_tuple(rendered, out_color, out_depth, radii, geomBuffer, binningBuffer, imgBuffer);
115117
}
116118

117119
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -214,4 +216,4 @@ torch::Tensor markVisible(
214216
}
215217

216218
return present;
217-
}
219+
}

rasterize_points.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include <tuple>
1616
#include <string>
1717

18-
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
18+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
1919
RasterizeGaussiansCUDA(
2020
const torch::Tensor& background,
2121
const torch::Tensor& means3D,
@@ -64,4 +64,4 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
6464
torch::Tensor markVisible(
6565
torch::Tensor& means3D,
6666
torch::Tensor& viewmatrix,
67-
torch::Tensor& projmatrix);
67+
torch::Tensor& projmatrix);

0 commit comments

Comments
 (0)