in latent_operators.py [0:0]
def generate_translation_matrices(self, cardinals, z_dim):
    """Generate one cyclic translation (DFT) operator per transformation.

    For every cardinal ``K`` in ``cardinals`` this builds the ``K x K`` DFT
    matrix ``W[r, x] = exp(2*pi*i * r * x / K)`` and selects its columns with
    the cyclic index ``c % K`` for ``c`` in ``range(z_dim)``.  Latent
    dimension ``c`` is therefore assigned to slot ``c % K``; when ``z_dim`` is
    not a multiple of ``K`` the last cycle is simply truncated.

    Args:
        cardinals: sequence of positive ints, one per transformation.
        z_dim: number of latent dimensions (columns of each matrix).

    Returns:
        list of ``(real, imag)`` tuples, one per entry of ``cardinals`` and in
        the same order.  Each element is a ``torch.float32`` tensor of shape
        ``(K, z_dim)`` created on ``self.device``.

    Raises:
        ValueError: if any cardinal is not positive.
    """

    def dft_matrix(cardinal):
        """Return the ``cardinal x cardinal`` complex DFT matrix."""
        rows, cols = np.meshgrid(
            np.arange(cardinal), np.arange(cardinal), indexing="ij"
        )
        # Reduce the exponent modulo ``cardinal`` before exponentiating:
        # raising a rounded root of unity to large powers accumulates error,
        # whereas the reduced phase keeps every entry as accurate as one exp.
        phase = (rows * cols) % cardinal
        return np.exp(2j * np.pi * phase / cardinal)

    all_translation_matrices = []
    for cardinal in cardinals:
        if cardinal <= 0:
            raise ValueError(f"cardinals must be positive, got {cardinal}")
        # Cyclic column index 0, 1, ..., K-1, 0, 1, ... truncated to z_dim.
        # This is the same as tiling arange(K) ceil(z_dim / K) times, and it
        # uses the *argument's* cardinal so indices always fit the DFT matrix.
        column_index = np.arange(z_dim) % cardinal
        matrix = dft_matrix(cardinal)[:, column_index]
        real = torch.tensor(
            np.real(matrix), dtype=torch.float32, device=self.device,
        )
        imag = torch.tensor(
            np.imag(matrix), dtype=torch.float32, device=self.device,
        )
        all_translation_matrices.append((real, imag))
    return all_translation_matrices