# transforms.py
def forward(self, image: torch.Tensor):
    C, H, W = image.shape
    if C != 4:
        err_msg = (
            f"This transform expects 4-channel RGBD input; got shape {image.shape}"
        )
        raise ValueError(err_msg)
    color_img = image[:3, ...]  # (3, H, W)
    depth_img = image[3:4, ...]  # (1, H, W)
    # Clamp below at min_depth to remove negative depth values
    depth_img = depth_img.clamp(min=self.min_depth)
    # Optionally clamp above at max_depth so the scaled depth lies in [0, 1]
    if self.clamp_max_before_scale:
        depth_img = depth_img.clamp(max=self.max_depth)
    # Normalize depth by max_depth
    depth_img /= self.max_depth
    img = torch.cat([color_img, depth_img], dim=0)
    return img
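

# For context, a minimal sketch of the module this forward() could live in.
# The class name "DepthNorm" and the constructor defaults are assumptions made
# for illustration; the snippet above only tells us the instance carries
# min_depth, max_depth, and clamp_max_before_scale attributes.
import torch


class DepthNorm(torch.nn.Module):  # hypothetical name
    def __init__(
        self,
        max_depth: float,
        min_depth: float = 0.0,
        clamp_max_before_scale: bool = False,
    ):
        super().__init__()
        if max_depth <= 0.0:
            raise ValueError("max_depth must be positive")
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.clamp_max_before_scale = clamp_max_before_scale

    # forward() as defined above


# Usage sketch: normalize the depth channel of a synthetic RGBD tensor.
transform = DepthNorm(max_depth=10.0, clamp_max_before_scale=True)
rgbd = torch.rand(4, 224, 224) * 12.0  # depth values may exceed max_depth
out = transform(rgbd)  # color channels untouched, depth scaled into [0, 1]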