in augerino_lib/uniform_aug.py [0:0]
def generate(self, weights):
    """
    Return the sum of the scaled generator matrices.

    Six fixed 3x3 generator matrices are each scaled by one column of
    ``weights`` and summed, yielding one 3x3 matrix per row of ``weights``.

    The generators are cached on ``self`` as ``g0`` .. ``g5`` (each expanded
    to shape ``(3, 3, bs)``) together with ``std_batch_size``. The cache is
    rebuilt when it is empty, when the batch size changes, or when
    ``weights`` lives on a different device than the cached tensors, so a
    module moved between devices never mixes tensors from two devices.

    Args:
        weights: 2-D tensor; the first dimension is the batch size ``bs``
            and columns 0..5 hold the coefficients of ``g0`` .. ``g5``.

    Returns:
        Tensor of shape ``(bs, 3, 3)`` whose ``[b]`` entry is
        ``sum_k weights[b, k] * g_k``.
    """
    bs = weights.shape[0]
    device = weights.device

    cache_is_stale = (
        self.g0 is None
        or self.std_batch_size != bs
        or self.g0.device != device
    )
    if cache_is_stale:
        self.std_batch_size = bs
        trans = 1. * self.trans_scale

        # One entry per generator: the (row, col, value) cells that are
        # non-zero in an otherwise all-zero 3x3 matrix.
        layouts = (
            ((0, 2, trans),),               # tx
            ((1, 2, trans),),               # ty
            ((0, 1, -1.), (1, 0, 1.)),      # antisymmetric (presumably rotation)
            ((0, 0, 1.), (1, 1, 1.)),       # diag(1, 1) (presumably uniform scale)
            ((0, 0, 1.), (1, 1, -1.)),      # diag(1, -1) (presumably stretch)
            ((0, 1, 1.), (1, 0, 1.)),       # symmetric off-diagonal (presumably shear)
        )

        generators = []
        for layout in layouts:
            basis = torch.zeros(3, 3, device=device)
            for row, col, value in layout:
                basis[row, col] = value
            # expand() adds the batch axis as a view; no per-sample copy.
            generators.append(basis.unsqueeze(-1).expand(3, 3, bs))

        (self.g0, self.g1, self.g2,
         self.g3, self.g4, self.g5) = generators

    # weights[:, k] has shape (bs,) and broadcasts over the trailing batch
    # axis of each (3, 3, bs) generator. The first product allocates a fresh
    # tensor, so the in-place accumulation never touches the cached views.
    out_mat = weights[:, 0] * self.g0
    remaining = (self.g1, self.g2, self.g3, self.g4, self.g5)
    for k, gen in enumerate(remaining, start=1):
        out_mat += weights[:, k] * gen

    # (3, 3, bs) -> (bs, 3, 3); same view as transpose(0, 2).transpose(2, 1).
    return out_mat.permute(2, 0, 1)