in depth_upsampling/models/msg/msg.py [0:0]
def forward(self, batch):
    """Predict a high-resolution depth map guided by an RGB image.

    The low-resolution depth is min-max normalized per sample and bicubically
    upsampled to the RGB spatial size. That upsampled map is the base onto which
    the decoder's output ``rec`` is added as a residual. The sum is then mapped
    back to the original depth range.

    Args:
        batch: mapping holding ``dataset_keys.COLOR_IMG`` and
            ``dataset_keys.LOW_RES_DEPTH_IMG`` tensors. Both are treated as
            4-D (reduction over dims 1-3 and ``shape[2:]`` used as the spatial
            size); presumably (N, C, H, W) — TODO confirm against the dataset.

    Returns:
        dict with the single key ``dataset_keys.PREDICTION_DEPTH_IMG`` mapping
        to the predicted depth, in the same value range as the input depth.
    """
    # Scale colour values down by 255; assumes 8-bit (0-255) pixels — TODO confirm.
    rgb_img = batch[dataset_keys.COLOR_IMG] / 255
    low_res_depth = batch[dataset_keys.LOW_RES_DEPTH_IMG]

    # Per-sample min-max statistics over all non-batch dims, kept broadcastable.
    min_d = low_res_depth.amin((1, 2, 3), keepdim=True)
    max_d = low_res_depth.amax((1, 2, 3), keepdim=True)
    # One shared scale so normalization and de-normalization are exact inverses;
    # the epsilon prevents division by zero for constant-depth samples.
    # NOTE(review): 1e-8 underflows to 0 in float16 — revisit if run under AMP.
    depth_scale = (max_d - min_d) + 1e-8
    low_res_depth_norm = (low_res_depth - min_d) / depth_scale

    # Base prediction: normalized depth upsampled to the RGB resolution.
    # align_corners=False is PyTorch's default; stated explicitly for clarity.
    low_res_upsampled = F.interpolate(
        low_res_depth_norm,
        size=rgb_img.shape[2:],
        mode='bicubic',
        align_corners=False,
    )

    # RGB encoder: keep the output of every stage for the decoder's skip connections.
    rgb_features = [self.rgb_encoder1(rgb_img)]
    for block in self.rgb_encoder_blocks:
        rgb_features.append(block(rgb_features[-1]))

    # Depth decoder: start from the low-res normalized depth and fuse the RGB
    # features from the last encoder stage backwards, ending with the first one.
    rec = self.depth_decoder1(low_res_depth_norm)
    for i, block in enumerate(self.depth_decoder_blocks):
        rec = torch.cat((rec, rgb_features[-(i + 1)]), 1)
        rec = block(rec)
    rec = torch.cat((rec, rgb_features[0]), 1)
    rec = self.depth_decoder_n(rec)

    # Add the residual to the upsampled base and undo the normalization with
    # the same scale that was used to apply it.
    output = (low_res_upsampled + rec) * depth_scale + min_d
    return {dataset_keys.PREDICTION_DEPTH_IMG: output}