in latent_operators.py [0:0]
def translate_batch(self, z_batch, angles):
    """Apply the shift operator to every latent code in a batch.

    Each angle is converted into an integer number of discrete shifts of
    size ``360 / (self.n_rotations + 1)`` degrees, and ``self.translate`` is
    applied to the matching element of ``z_batch`` with that shift count.

    Args:
        z_batch: batch of latent codes, iterated along its first dimension;
            element ``i`` is paired with angle ``i``.
        angles (tensor or array-like of floats): counter-clockwise rotation
            in degrees, one per batch element. If it has more than one
            dimension, only the first column (``angles[:, 0]``) is used.

    Returns:
        torch.Tensor: the translated codes stacked along a new first dimension.

    Raises:
        IndexError: if ``angles`` has fewer entries than ``z_batch`` (or if
            ``self.translate`` itself raises IndexError). The original error
            is chained and the offending angles are included in the message.
    """
    # Accept lists / numpy arrays as well as tensors; no copy for tensors.
    angles = torch.as_tensor(angles)
    smallest_angle = 360 / (self.n_rotations + 1)
    # Only the first column is used when angles carries extra columns.
    per_item = angles[:, 0] if angles.dim() > 1 else angles
    # Angles are expected to be multiples of smallest_angle. Round instead of
    # truncating with .long() so float error such as 1.9999999 maps to 2, not 1.
    shifts = torch.round(per_item / smallest_angle).long()
    try:
        translated_batch = [
            self.translate(z, shifts[i]) for i, z in enumerate(z_batch)
        ]
    except IndexError as e:
        raise IndexError(
            f"translate_batch failed for angles={angles}"
        ) from e
    return torch.stack(translated_batch)