Skip to content

Commit d26ea87

Browse files
committed
Add backward pass
1 parent 4a3f789 commit d26ea87

File tree

7 files changed

+30
-3
lines changed

7 files changed

+30
-3
lines changed

cuda_rasterizer/backward.cu

+17-1
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,11 @@ renderCUDA(
406406
const float2* __restrict__ points_xy_image,
407407
const float4* __restrict__ conic_opacity,
408408
const float* __restrict__ colors,
409+
const float* __restrict__ depths,
409410
const float* __restrict__ final_Ts,
410411
const uint32_t* __restrict__ n_contrib,
411412
const float* __restrict__ dL_dpixels,
413+
const float* __restrict__ dL_depths,
412414
float3* __restrict__ dL_dmean2D,
413415
float4* __restrict__ dL_dconic2D,
414416
float* __restrict__ dL_dopacity,
@@ -435,6 +437,7 @@ renderCUDA(
435437
__shared__ float2 collected_xy[BLOCK_SIZE];
436438
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
437439
__shared__ float collected_colors[C * BLOCK_SIZE];
440+
__shared__ float collected_depths[BLOCK_SIZE];
438441

439442
// In the forward, we stored the final value for T, the
440443
// product of all (1 - alpha) factors.
@@ -448,12 +451,16 @@ renderCUDA(
448451

449452
float accum_rec[C] = { 0 };
450453
float dL_dpixel[C];
454+
float dL_depth;
455+
float accum_depth_rec = 0;
451456
if (inside)
452457
for (int i = 0; i < C; i++)
453458
dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
459+
dL_depth = dL_depths[pix_id];
454460

455461
float last_alpha = 0;
456462
float last_color[C] = { 0 };
463+
float last_depth = 0;
457464

458465
// Gradient of pixel coordinate w.r.t. normalized
459466
// screen-space viewport corrdinates (-1 to 1)
@@ -475,6 +482,7 @@ renderCUDA(
475482
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
476483
for (int i = 0; i < C; i++)
477484
collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
485+
collected_depths[block.thread_rank()] = depths[coll_id];
478486
}
479487
block.sync();
480488

@@ -522,6 +530,10 @@ renderCUDA(
522530
// many that were affected by this Gaussian.
523531
atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
524532
}
533+
const float c_d = collected_depths[j];
534+
accum_depth_rec = last_alpha * last_depth + (1.f - last_alpha) * accum_depth_rec;
535+
last_depth = c_d;
536+
dL_dalpha += (c_d - accum_depth_rec) * dL_depth;
525537
dL_dalpha *= T;
526538
// Update last alpha (to be used in the next iteration)
527539
last_alpha = alpha;
@@ -630,9 +642,11 @@ void BACKWARD::render(
630642
const float2* means2D,
631643
const float4* conic_opacity,
632644
const float* colors,
645+
const float* depths,
633646
const float* final_Ts,
634647
const uint32_t* n_contrib,
635648
const float* dL_dpixels,
649+
const float* dL_depths,
636650
float3* dL_dmean2D,
637651
float4* dL_dconic2D,
638652
float* dL_dopacity,
@@ -646,12 +660,14 @@ void BACKWARD::render(
646660
means2D,
647661
conic_opacity,
648662
colors,
663+
depths,
649664
final_Ts,
650665
n_contrib,
651666
dL_dpixels,
667+
dL_depths,
652668
dL_dmean2D,
653669
dL_dconic2D,
654670
dL_dopacity,
655671
dL_dcolors
656672
);
657-
}
673+
}

cuda_rasterizer/backward.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ namespace BACKWARD
2929
const float2* means2D,
3030
const float4* conic_opacity,
3131
const float* colors,
32+
const float* depths,
3233
const float* final_Ts,
3334
const uint32_t* n_contrib,
3435
const float* dL_dpixels,
36+
const float* dL_depths,
3537
float3* dL_dmean2D,
3638
float4* dL_dconic2D,
3739
float* dL_dopacity,
@@ -62,4 +64,4 @@ namespace BACKWARD
6264
glm::vec4* dL_drot);
6365
}
6466

65-
#endif
67+
#endif

cuda_rasterizer/rasterizer.h

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ namespace CudaRasterizer
7373
char* binning_buffer,
7474
char* image_buffer,
7575
const float* dL_dpix,
76+
const float* dL_depths,
7677
float* dL_dmean2D,
7778
float* dL_dconic,
7879
float* dL_dopacity,

cuda_rasterizer/rasterizer_impl.cu

+4
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ void CudaRasterizer::Rasterizer::backward(
360360
char* binning_buffer,
361361
char* img_buffer,
362362
const float* dL_dpix,
363+
const float* dL_depths,
363364
float* dL_dmean2D,
364365
float* dL_dconic,
365366
float* dL_dopacity,
@@ -390,6 +391,7 @@ void CudaRasterizer::Rasterizer::backward(
390391
// opacity and RGB of Gaussians from per-pixel loss gradients.
391392
// If we were given precomputed colors and not SHs, use them.
392393
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
394+
const float* depth_ptr = geomState.depths;
393395
CHECK_CUDA(BACKWARD::render(
394396
tile_grid,
395397
block,
@@ -400,9 +402,11 @@ void CudaRasterizer::Rasterizer::backward(
400402
geomState.means2D,
401403
geomState.conic_opacity,
402404
color_ptr,
405+
depth_ptr,
403406
imgState.accum_alpha,
404407
imgState.n_contrib,
405408
dL_dpix,
409+
dL_depths,
406410
(float3*)dL_dmean2D,
407411
(float4*)dL_dconic,
408412
dL_dopacity,

diff_gaussian_rasterization/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def backward(ctx, grad_out_color, grad_radii, grad_depth):
118118
raster_settings.projmatrix,
119119
raster_settings.tanfovx,
120120
raster_settings.tanfovy,
121-
grad_out_color,
121+
grad_out_color,
122+
grad_depth,
122123
sh,
123124
raster_settings.sh_degree,
124125
raster_settings.campos,

rasterize_points.cu

+2
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
131131
const float tan_fovx,
132132
const float tan_fovy,
133133
const torch::Tensor& dL_dout_color,
134+
const torch::Tensor& dL_dout_depth,
134135
const torch::Tensor& sh,
135136
const int degree,
136137
const torch::Tensor& campos,
@@ -182,6 +183,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
182183
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
183184
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
184185
dL_dout_color.contiguous().data<float>(),
186+
dL_dout_depth.contiguous().data<float>(),
185187
dL_dmeans2D.contiguous().data<float>(),
186188
dL_dconic.contiguous().data<float>(),
187189
dL_dopacity.contiguous().data<float>(),

rasterize_points.h

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
5252
const float tan_fovx,
5353
const float tan_fovy,
5454
const torch::Tensor& dL_dout_color,
55+
const torch::Tensor& dL_dout_depth,
5556
const torch::Tensor& sh,
5657
const int degree,
5758
const torch::Tensor& campos,

0 commit comments

Comments
 (0)