in train.py [0:0]
def jigsaw(xx, x_emb, nprng=None, yy=None):
    """Apply a per-example jigsaw remapping to a batch and its embeddings.

    Each example in ``xx`` and the matching slice of ``x_emb`` is passed
    through ``remap_jigsaw`` with the same permutation id, so the two stay
    aligned with each other.

    Args:
        xx: Batch array; it is reshaped to ``dataset.orig_shape`` before
            the remapping and to ``dataset.shape`` afterwards.
        x_emb: 3-D array of shape ``(batch, n_emb, n_ctx)`` where ``n_ctx``
            must equal the product of ``dataset.orig_shape[1:]``.
        nprng: Forwarded to ``sample_augmentation_type`` when ``yy`` has to
            be sampled; unused otherwise.
        yy: Optional array of permutation ids, indexed as ``yy[i][0]`` for
            example ``i``. When ``None`` it is drawn via
            ``sample_augmentation_type(H.jigsaw_num_perms, size=(batch, 1))``.

    Returns:
        Tuple ``(xx, x_emb, yy)``: the remapped batch in ``dataset.shape``,
        the remapped embeddings in their original
        ``(batch, n_emb, n_ctx)`` layout, and the permutation ids used.

    Raises:
        AssertionError: If the batch sizes of ``xx``, ``x_emb`` and ``yy``
            disagree, or ``n_ctx`` does not match ``dataset.orig_shape``.
    """
    batch = xx.shape[0]
    emb_batch, n_emb, n_ctx = x_emb.shape
    # Presumably rows, columns and channels of one example -- confirm
    # against the definition of dataset.orig_shape.
    n_rows, n_cols, n_chan = dataset.orig_shape[1:]
    assert batch == emb_batch
    assert n_ctx == np.prod(dataset.orig_shape[1:])

    xx = xx.reshape(dataset.orig_shape)
    if yy is None:
        yy = sample_augmentation_type(
            H.jigsaw_num_perms, size=(batch, 1), nprng=nprng
        )
    assert yy.shape[0] == xx.shape[0]

    # Fold the embedding axis into the first spatial axis so every example
    # becomes a single (n_emb * n_rows, n_cols, n_chan) array that
    # remap_jigsaw can process in the same way as xx[i].
    x_emb = (
        x_emb.reshape((batch, n_emb, n_rows, n_cols, n_chan))
        .transpose(0, 2, 1, 3, 4)
        .reshape((batch, n_emb * n_rows, n_cols, n_chan))
    )

    # Image first, then embedding, for each example in turn.
    remapped = [
        (remap_jigsaw(xx[idx], perm[0]), remap_jigsaw(x_emb[idx], perm[0]))
        for idx, perm in enumerate(yy)
    ]
    image_parts = [pair[0] for pair in remapped]
    emb_parts = [pair[1] for pair in remapped]

    xx = np.concatenate(image_parts, axis=0).reshape(dataset.shape)

    # Undo the folding above to get back to (batch, n_emb, n_ctx).
    x_emb = (
        np.concatenate(emb_parts, axis=0)
        .reshape((batch, n_rows, n_emb, n_cols, n_chan))
        .transpose(0, 2, 1, 3, 4)
        .reshape((emb_batch, n_emb, n_ctx))
    )
    return xx, x_emb, yy