in ignite/metrics/ssim.py [0:0]
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
if y_pred.dtype != y.dtype:
raise TypeError(
f"Expected y_pred and y to have the same data type. Got y_pred: {y_pred.dtype} and y: {y.dtype}."
)
if y_pred.shape != y.shape:
raise ValueError(
f"Expected y_pred and y to have the same shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
)
if len(y_pred.shape) != 4 or len(y.shape) != 4:
raise ValueError(
f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
)
channel = y_pred.size(1)
if len(self._kernel.shape) < 4:
self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device)
y_pred = F.pad(y_pred, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
y = F.pad(y, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
input_list = torch.cat([y_pred, y, y_pred * y_pred, y * y, y_pred * y])
outputs = F.conv2d(input_list, self._kernel, groups=channel)
output_list = [outputs[x * y_pred.size(0) : (x + 1) * y_pred.size(0)] for x in range(len(outputs))]
mu_pred_sq = output_list[0].pow(2)
mu_target_sq = output_list[1].pow(2)
mu_pred_target = output_list[0] * output_list[1]
sigma_pred_sq = output_list[2] - mu_pred_sq
sigma_target_sq = output_list[3] - mu_target_sq
sigma_pred_target = output_list[4] - mu_pred_target
a1 = 2 * mu_pred_target + self.c1
a2 = 2 * sigma_pred_target + self.c2
b1 = mu_pred_sq + mu_target_sq + self.c1
b2 = sigma_pred_sq + sigma_target_sq + self.c2
ssim_idx = (a1 * a2) / (b1 * b2)
self._sum_of_batchwise_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).to(self._device)
self._num_examples += y.shape[0]