def forward()

in depth_upsampling/models/msg/msg.py [0:0]


    def forward(self, batch):
        """Predict full-resolution depth from an RGB image and a low-resolution depth map.

        The low-resolution depth is min-max normalized per sample, then decoded
        with the RGB encoder's feature pyramid as skip connections. The decoder
        output is added as a residual to a bicubic upsample of the normalized
        depth, and the sum is mapped back to the input depth's value range.

        Args:
            batch: mapping holding ``dataset_keys.COLOR_IMG`` (4-D image tensor,
                presumably 0-255 pixel values given the ``/ 255`` scaling —
                confirm against the dataset) and ``dataset_keys.LOW_RES_DEPTH_IMG``
                (4-D depth tensor; per-sample statistics are taken over axes 1-3).

        Returns:
            dict with a single entry ``dataset_keys.PREDICTION_DEPTH_IMG``: the
            predicted depth. Its spatial size is that of the color image, because
            the decoder residual is added to a bicubic upsample sized to it.
        """
        color = batch[dataset_keys.COLOR_IMG] / 255
        coarse_depth = batch[dataset_keys.LOW_RES_DEPTH_IMG]

        # Per-sample min/max over every non-batch axis; keepdim so the
        # statistics broadcast back against the depth tensor.
        depth_lo = coarse_depth.amin(dim=(1, 2, 3), keepdim=True)
        depth_hi = coarse_depth.amax(dim=(1, 2, 3), keepdim=True)
        depth_span = depth_hi - depth_lo
        # The epsilon avoids division by zero for constant-depth samples.
        depth_unit = (coarse_depth - depth_lo) / (depth_span + 1e-8)

        # Bicubic baseline at the color image's spatial size. It is computed
        # before the decoders run, matching the original ordering.
        baseline = F.interpolate(depth_unit, size=color.shape[2:], mode='bicubic')

        # RGB encoder pyramid: each block consumes the previous stage's output.
        stage = self.rgb_encoder1(color)
        guidance = [stage]
        for encoder_block in self.rgb_encoder_blocks:
            stage = encoder_block(stage)
            guidance.append(stage)

        # The decoder starts from the low-res normalized depth (not the
        # upsampled one). It fuses encoder features deepest-first by channel
        # concatenation, then the shallowest features before the final head.
        residual = self.depth_decoder1(depth_unit)
        for level, decoder_block in enumerate(self.depth_decoder_blocks, start=1):
            residual = decoder_block(torch.cat((residual, guidance[-level]), dim=1))
        residual = self.depth_decoder_n(torch.cat((residual, guidance[0]), dim=1))

        # Add the residual to the baseline, then undo the min-max normalization.
        # NOTE(review): this rescales by depth_span without the 1e-8 used when
        # normalizing, so the round trip is not an exact inverse. Confirm intended.
        prediction = (baseline + residual) * depth_span + depth_lo
        return {dataset_keys.PREDICTION_DEPTH_IMG: prediction}