diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 4aa41e1c..d58f0683 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -530,7 +530,7 @@ renderCUDA( // the background color is added if nothing left to blend float bg_dot_dpixel = 0; for (int i = 0; i < C; i++) - bg_dot_dpixel += bg_color[i] * dL_dpixel[i]; + bg_dot_dpixel += bg_color[i + pix_id * C] * dL_dpixel[i]; dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel; diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index c419a328..369a7f31 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -369,7 +369,7 @@ renderCUDA( final_T[pix_id] = T; n_contrib[pix_id] = last_contributor; for (int ch = 0; ch < CHANNELS; ch++) - out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch]; + out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch + pix_id * CHANNELS]; } } diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index bbef37d1..f5a3fe7b 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -55,10 +55,14 @@ def forward( cov3Ds_precomp, raster_settings, ): + # Flatten from CxHxW + bg = raster_settings.bg if len(raster_settings.bg.shape) == 1 else raster_settings.bg.permute(1,2,0).flatten() + pixel_count = raster_settings.image_height * raster_settings.image_width + bg = bg if len(bg) > 3 else bg.repeat(pixel_count) # Restructure arguments the way that the C++ lib expects them args = ( - raster_settings.bg, + bg, means3D, colors_precomp, opacities, @@ -105,8 +109,13 @@ def backward(ctx, grad_out_color, _): raster_settings = ctx.raster_settings colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors + # Flatten from CxHxW + bg = raster_settings.bg if len(raster_settings.bg.shape) == 1 else raster_settings.bg.permute(1,2,0).flatten() + pixel_count = raster_settings.image_height * raster_settings.image_width + bg = bg if len(bg) > 3 else bg.repeat(pixel_count) + # Restructure args as C++ method expects them - args = (raster_settings.bg, + args = (bg, means3D, radii, colors_precomp,