in weakly_complex_shift_autoencoder.py [0:0]
def reconstruction_mse_transformed_z1_weak(self, x1, x2, angles, use_argmax=False):
    """Computes the reconstruction MSE of x1 from z1, plus of x2 from transformed(z1),
    without using the ground-truth angles (the `angles` argument is unused)."""
    criterion = torch.nn.MSELoss(reduction="none")
    batch_size = x1.size(0)
    z1 = self.encoder(x1)
    z2 = self.encoder(x2)
    prod_size = np.prod(x1.size())  # total number of elements in x1 (batch * C * H * W)
    # Plain reconstruction of x1 from its own latent code
    x1_reconstruction_r = self.decoder(z1)
    x1_reconstruction_loss = criterion(x1_reconstruction_r, x1)
    x1_reconstruction_loss = x1_reconstruction_loss.mean()
    # TODO: this is not adapted to a product of shift operators; it only looks at the first cardinal
    # Transform according to all possible angles, weighted by their predicted probabilities
    angles_probas = self.compute_angles_probas(x1, x2, z1, z2)
    if use_argmax:
        # Hard assignment: commit to the most likely angle and reconstruct x2 from it
        predicted_angle = angles_probas.detach().argmax(-1, keepdim=True)
        z_transformed = self.transform(z1, predicted_angle)
        x2_reconstruction_r = self.decoder(z_transformed)
        x2_reconstruction_loss = criterion(x2_reconstruction_r, x2)
        x2_reconstruction_loss = x2_reconstruction_loss.mean()
    else:
        # Soft assignment: evaluate every candidate angle 0..K-1 for every batch element
        all_angles = torch.arange(self.K).repeat(1, batch_size).view(-1, 1)
        temp = self.temperature
        mask = torch.softmax(angles_probas / temp, dim=-1)
        # Repeat both components of the latent code K times so each candidate angle can be applied
        repeat_z1 = (
            z1[0][:, None, :].repeat(1, self.K, 1).view(batch_size * self.K, -1),
            z1[1][:, None, :].repeat(1, self.K, 1).view(batch_size * self.K, -1),
        )
        x2_repeat = (
            x2[:, None, ...]
            .repeat(1, self.K, 1, 1, 1)
            .view(batch_size * self.K, x2.size(1), x2.size(2), x2.size(3))
        )
        z_transformed = self.transform(repeat_z1, all_angles)
        x2_reconstruction_r = self.decoder(z_transformed)
        x2_reconstruction_transformed_loss = (
            criterion(x2_reconstruction_r, x2_repeat)
            .sum((1, 2, 3))  # sum over the image dimensions
            .view(batch_size, -1)  # one error per (sample, candidate angle)
        )
        # Expected reconstruction error under the softmax weights, normalized by the total pixel count
        x2_reconstruction_loss = (mask * x2_reconstruction_transformed_loss).sum() / prod_size
    loss = x1_reconstruction_loss + x2_reconstruction_loss
    return loss
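
Below is a minimal, self-contained sketch (not part of the file) of the soft-weighting step in the else branch, run on dummy tensors. The names `logits`, `per_angle_loss`, `batch_size`, `K`, and `temperature` are illustrative stand-ins for `angles_probas`, the per-angle summed MSE, and the model attributes: rather than committing to a single argmax angle, the per-angle reconstruction errors are combined under softmax weights.

import torch

batch_size, K, temperature = 4, 8, 0.5
logits = torch.randn(batch_size, K)         # stand-in for angles_probas
per_angle_loss = torch.rand(batch_size, K)  # stand-in for the per-angle summed MSE

weights = torch.softmax(logits / temperature, dim=-1)  # soft assignment over the K candidate angles
expected = (weights * per_angle_loss).sum(dim=-1)      # expected reconstruction error per sample
loss = expected.mean()                                 # reduce to a scalar; the method above instead
                                                       # divides the weighted sum by the total pixel count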