diff --git a/mmdet3d/structures/ops/transforms.py b/mmdet3d/structures/ops/transforms.py index 8e9f7006ac..1a698d9f4c 100644 --- a/mmdet3d/structures/ops/transforms.py +++ b/mmdet3d/structures/ops/transforms.py @@ -37,11 +37,12 @@ def bbox3d2roi(bbox_list): """ rois_list = [] for img_id, bboxes in enumerate(bbox_list): + img_inds = bboxes.new_full((bboxes.size(0), 1), img_id) if bboxes.size(0) > 0: - img_inds = bboxes.new_full((bboxes.size(0), 1), img_id) rois = torch.cat([img_inds, bboxes], dim=-1) else: rois = torch.zeros_like(bboxes) + rois = torch.cat([img_inds, rois], dim=-1) rois_list.append(rois) rois = torch.cat(rois_list, 0) return rois