def forward()

in activemri/experimental/cvpr19_models/models/reconstruction.py


    def forward(self, zero_filled_input, mask):
        """Generates reconstructions given images with partial k-space info.

        Args:
            zero_filled_input(torch.Tensor): Image obtained from zero-filled reconstruction
                of partial k-space scans.
            mask(torch.Tensor): Mask used to create the zero-filled image from the
                ground truth image.

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor): Contains:
                * Reconstructed high-resolution image.
                * Uncertainty map.
                * Mask embedding (None if mask embedding is not used).
        """
        if self.use_mask_embedding:
            mask_embedding = self.embed_mask(mask)
            mask_embedding = mask_embedding.repeat(
                1, 1, zero_filled_input.shape[2], zero_filled_input.shape[3]
            )
            encoder_input = torch.cat([zero_filled_input, mask_embedding], 1)
        else:
            encoder_input = zero_filled_input
            mask_embedding = None

        residual_bottleneck_output = None
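        # Run the cascade: each block refines the previous reconstruction through an
        # encoder, a residual bottleneck, and a decoder.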
        for cascade_block, (encoder, residual_bottleneck, decoder) in enumerate(
            zip(
                self.encoders_all_cascade_blocks,
                self.residual_bottlenecks_all_cascade_blocks,
                self.decoders_all_cascade_blocks,
            )
        ):
            encoder_output = encoder(encoder_input)
            if cascade_block > 0:
                # Skip connection from the previous cascade block's residual bottleneck
                encoder_output = encoder_output + residual_bottleneck_output

            residual_bottleneck_output = residual_bottleneck(encoder_output)

            decoder_output = decoder(residual_bottleneck_output)

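            # All decoder channels except the last form the image estimate, which is
            # made consistent with the measured k-space via the sampling mask; the
            # last channel is a per-pixel uncertainty estimate.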
            reconstructed_image = self.data_consistency(
                decoder_output[:, :-1, ...], zero_filled_input, mask
            )
            uncertainty_map = decoder_output[:, -1:, :, :]

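            # The current reconstruction (plus the mask embedding, if used) becomes
            # the input to the next cascade block.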
            if self.use_mask_embedding:
                encoder_input = torch.cat([reconstructed_image, mask_embedding], 1)
            else:
                encoder_input = reconstructed_image

        return reconstructed_image, uncertainty_map, mask_embedding
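
To make the call pattern and tensor shapes concrete, the self-contained sketch below mirrors the same control flow (mask embedding, cascade of encoder / residual bottleneck / decoder, data consistency, uncertainty channel) with toy placeholder layers. The layer choices, channel counts, and the FFT-based data consistency are assumptions for illustration only and are not the actual network defined in reconstruction.py.

import torch
import torch.nn as nn


class ToyCascadeReconstructor(nn.Module):
    # Stand-in model: 3x3 convolutions replace the real encoder/bottleneck/decoder,
    # and data consistency is a simple k-space replacement along masked columns.
    def __init__(self, img_width=64, hidden=16, cascade_blocks=2, mask_embed_dim=4):
        super().__init__()
        # Maps a (B, 1, 1, W) column mask to a (B, mask_embed_dim, 1, 1) embedding.
        self.embed_mask = nn.Sequential(
            nn.Flatten(),
            nn.Linear(img_width, mask_embed_dim),
            nn.Unflatten(1, (mask_embed_dim, 1, 1)),
        )
        in_channels = 1 + mask_embed_dim  # image channel + tiled mask embedding
        self.encoders = nn.ModuleList(
            [nn.Conv2d(in_channels, hidden, 3, padding=1) for _ in range(cascade_blocks)]
        )
        self.bottlenecks = nn.ModuleList(
            [nn.Conv2d(hidden, hidden, 3, padding=1) for _ in range(cascade_blocks)]
        )
        # Each decoder emits two channels: the image estimate and an uncertainty map.
        self.decoders = nn.ModuleList(
            [nn.Conv2d(hidden, 2, 3, padding=1) for _ in range(cascade_blocks)]
        )

    def data_consistency(self, predicted_image, zero_filled_input, mask):
        # Keep measured k-space columns from the input, the rest from the prediction.
        pred_k = torch.fft.fft2(predicted_image)
        meas_k = torch.fft.fft2(zero_filled_input)
        return torch.fft.ifft2(mask * meas_k + (1 - mask) * pred_k).real

    def forward(self, zero_filled_input, mask):
        mask_embedding = self.embed_mask(mask).repeat(
            1, 1, zero_filled_input.shape[2], zero_filled_input.shape[3]
        )
        encoder_input = torch.cat([zero_filled_input, mask_embedding], 1)
        bottleneck_output = None
        for i, (encoder, bottleneck, decoder) in enumerate(
            zip(self.encoders, self.bottlenecks, self.decoders)
        ):
            encoder_output = encoder(encoder_input)
            if i > 0:
                encoder_output = encoder_output + bottleneck_output
            bottleneck_output = bottleneck(encoder_output)
            decoder_output = decoder(bottleneck_output)
            reconstructed_image = self.data_consistency(
                decoder_output[:, :-1, ...], zero_filled_input, mask
            )
            uncertainty_map = decoder_output[:, -1:, ...]
            encoder_input = torch.cat([reconstructed_image, mask_embedding], 1)
        return reconstructed_image, uncertainty_map, mask_embedding


batch, width = 2, 64
zero_filled = torch.randn(batch, 1, width, width)
mask = (torch.rand(batch, 1, 1, width) > 0.75).float()
recon, uncertainty, embedding = ToyCascadeReconstructor(img_width=width)(zero_filled, mask)
print(recon.shape, uncertainty.shape, embedding.shape)  # (2, 1, 64, 64), (2, 1, 64, 64), (2, 4, 64, 64)

In the real forward() above, mask_embedding is None when use_mask_embedding is False; the toy always embeds the mask for brevity.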