def reconstruction_mse_transformed_z1_weak()

in weakly_complex_shift_autoencoder.py [0:0]


    def reconstruction_mse_transformed_z1_weak(self, x1, x2, angles, use_argmax=False):
        """Return the MSE of x1 rebuilt from z1 plus x2 rebuilt from transformed z1.

        The transformation applied to ``z1`` is inferred from
        ``self.compute_angles_probas`` instead of ground-truth angles.

        Args:
            x1: Batch of source inputs. It is encoded, decoded and compared
                with itself.
            x2: Batch of target inputs. Presumably the same shape as ``x1``
                (the x2 term is normalised by ``x1.numel()``) -- TODO confirm.
            angles: Ground-truth angles. Not read by this method.
            use_argmax: If True, transform ``z1`` with the single most likely
                angle (taken on detached scores). If False, decode ``z1``
                transformed by each of the ``self.K`` candidate angles and
                weight the per-angle errors by
                ``softmax(angles_probas / self.temperature)``.

        Returns:
            Scalar tensor: x1 reconstruction loss + x2 reconstruction loss.
        """
        criterion = torch.nn.MSELoss(reduction="none")
        batch_size = x1.size(0)
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)

        x1_reconstruction_loss = criterion(self.decoder(z1), x1).mean()

        # TODO this is not adapted to product of shift operators, it's looking only at the 1st cardinal
        angles_probas = self.compute_angles_probas(x1, x2, z1, z2)

        if use_argmax:
            predicted_angle = angles_probas.detach().argmax(-1, keepdim=True)
            z_transformed = self.transform(z1, predicted_angle)
            x2_reconstruction_loss = criterion(self.decoder(z_transformed), x2).mean()
        else:
            n_candidates = batch_size * self.K

            def tile(part):
                # Each sample is repeated K times in a row (sample-major
                # order), then flattened to (batch_size * K, features).
                return (
                    part.unsqueeze(1)
                    .expand(-1, self.K, *part.shape[1:])
                    .reshape(n_candidates, -1)
                )

            # Candidate angles 0..K-1 for every sample, in the same
            # sample-major order as `tile`. Created on the device of
            # `angles_probas` so `transform` receives indices on the same
            # device as in the argmax branch.
            all_angles = (
                torch.arange(self.K, device=angles_probas.device)
                .repeat(batch_size)
                .view(-1, 1)
            )
            mask = torch.softmax(angles_probas / self.temperature, dim=-1)

            # Only the first two components of z1 are used.
            z_transformed = self.transform((tile(z1[0]), tile(z1[1])), all_angles)

            # Regroup the decoded candidates as (batch, K, *sample_shape) and
            # compare against a broadcast view of x2 (no K-fold copy of x2).
            x2_candidates = self.decoder(z_transformed).reshape(
                batch_size, self.K, *x2.shape[1:]
            )
            x2_target = x2.unsqueeze(1).expand_as(x2_candidates)
            per_angle_loss = (
                criterion(x2_candidates, x2_target)
                .flatten(2)  # merge all per-sample dims
                .sum(-1)  # summed squared error per (sample, angle)
            )
            # Dividing by the element count of x1 puts the weighted sum on the
            # same per-element scale as the `.mean()` used for x1.
            x2_reconstruction_loss = (mask * per_angle_loss).sum() / x1.numel()

        return x1_reconstruction_loss + x2_reconstruction_loss