in complex_shift_autoencoder.py [0:0]
def return_shifts(self, params):
    """Convert transformation parameters into integer grid-step indices.

    Each parameter is expressed as a whole number of steps, with step sizes
    derived from ``self.data``:

    * rotation: ``360 / (n_rotations + 1)`` (same unit as ``param.angle``),
    * x: ``round(n_pixels / (n_x_translations + 1))``,
    * y: ``round(n_pixels / (n_y_translations + 1))``.

    Args:
        params: Iterable of objects exposing ``angle``, ``shift_x`` and
            ``shift_y`` attributes (presumably plain numbers -- confirm
            against callers). It is materialised once, so generators work.

    Returns:
        A list of ``torch.long`` tensors of shape ``(len(params), 1)``, in
        the order rotation, x, y. An axis is included only when the
        corresponding ``self.data.n_*`` count is greater than zero.
    """
    data = self.data
    # Materialise once: the parameters are read once per enabled axis, and a
    # one-shot iterable would otherwise be exhausted after the first axis.
    params = list(params)

    def to_column(values, step):
        # Snap to the NEAREST grid index. Truncating would turn values such as
        # 2.9999997 (float error or rounded stored parameters) into 2, not 3.
        indices = [round(float(value) / step) for value in values]
        return torch.tensor(indices, dtype=torch.long).unsqueeze(1)

    shifts = []
    if data.n_rotations > 0:
        smallest_angle = 360 / (data.n_rotations + 1)
        shifts.append(to_column((p.angle for p in params), smallest_angle))
    if data.n_x_translations > 0:
        int_x = round(data.n_pixels / (data.n_x_translations + 1))
        shifts.append(to_column((p.shift_x for p in params), int_x))
    if data.n_y_translations > 0:
        int_y = round(data.n_pixels / (data.n_y_translations + 1))
        shifts.append(to_column((p.shift_y for p in params), int_y))
    return shifts