in train.py [0:0]
def permute_arbitrarily(random_perms):
    """Build an augmentation function that reorders context positions.

    The candidate orderings are the identity ``np.arange(dataset.ctx)``
    followed by the entries of ``random_perms``; an augmentation label is
    an index into that combined list, so label 0 selects the identity.

    Args:
        random_perms: list of index sequences, each usable to fancy-index
            the last axis of the inputs. It is concatenated to a list with
            ``+``, so it must itself be a list.

    Returns:
        A function ``fn(xx, x_emb, nprng=None, yy=None)`` returning the
        tuple ``(xx, x_emb, yy)`` with the permuted inputs and the labels
        that were applied.
    """
    identity = np.arange(dataset.ctx)
    orderings = [identity] + random_perms
    num_orderings = len(orderings)

    def fn(xx, x_emb, nprng=None, yy=None):
        """Permute each example of ``xx`` / ``x_emb`` along the context axis.

        Args:
            xx: 2-D array unpacked as ``(batch, n_ctx)``.
            x_emb: 3-D array unpacked as ``(batch, n_emb, n_ctx)``; its
                batch and context sizes must match those of ``xx``.
            nprng: forwarded to ``sample_augmentation_type`` when labels
                have to be drawn; unused otherwise.
            yy: optional labels, one row per example, where ``yy[i][0]``
                indexes the ordering applied to example ``i``. Drawn with
                ``sample_augmentation_type`` when omitted.

        Returns:
            ``(xx, x_emb, yy)``: the permuted arrays plus the labels used.
        """
        batch, n_ctx = xx.shape
        emb_batch, n_emb, emb_ctx = x_emb.shape
        assert batch == emb_batch
        assert n_ctx == emb_ctx

        if yy is None:
            yy = sample_augmentation_type(
                num_orderings, size=(batch, 1), nprng=nprng
            )
        assert yy.shape[0] == xx.shape[0]

        # Resolve each example's ordering once, then gather both inputs
        # with it: tokens along their only axis, embeddings along the last.
        chosen = [orderings[label[0]] for label in yy]
        gathered = [
            (xx[i][order], x_emb[i][:, order])
            for i, order in enumerate(chosen)
        ]

        # NOTE(review): the result is reshaped to the module-level
        # dataset.shape rather than (batch, n_ctx) -- confirm these agree.
        flat_tokens = np.concatenate([tok for tok, _ in gathered], axis=0)
        xx = flat_tokens.reshape(dataset.shape)

        stacked_emb = np.concatenate([emb for _, emb in gathered], axis=0)
        x_emb = stacked_emb.reshape(emb_batch, n_emb, emb_ctx)
        return xx, x_emb, yy

    return fn