def permute_arbitrarily()

in train.py [0:0]


def permute_arbitrarily(random_perms):
    """Build an augmentation that reorders each example's context positions.

    The returned function applies one of ``len(random_perms) + 1``
    permutations to every example of a batch. Augmentation type ``0`` is the
    identity permutation ``np.arange(dataset.ctx)``; type ``k >= 1`` selects
    ``random_perms[k - 1]``.

    Args:
        random_perms: Iterable of index sequences (a list, a tuple, or the rows
            of a 2-D array). All of them must have the same length as the
            identity permutation. Presumably each is a permutation of
            ``range(dataset.ctx)`` -- TODO confirm with callers.

    Returns:
        ``fn(xx, x_emb, nprng=None, yy=None) -> (xx, x_emb, yy)``.
    """
    # list(...) lets tuples and 2-D ndarrays work too. With an ndarray,
    # ``list + ndarray`` would broadcast-*add* the identity row instead of
    # prepending it.
    perms = [np.arange(dataset.ctx)] + list(random_perms)
    n = len(perms)
    # (n, ctx) lookup table, built once, so each call gathers with a single
    # vectorized indexing op instead of a per-example Python loop.
    perm_table = np.stack(perms)

    def fn(xx, x_emb, nprng=None, yy=None):
        """Permute the positions of ``xx`` and ``x_emb`` per example.

        Args:
            xx: 2-D array of shape ``(b, n_ctx)``.
            x_emb: 3-D array of shape ``(b, n_emb, n_ctx)``. Its last axis is
                permuted with the same permutation as the matching row of
                ``xx``.
            nprng: Forwarded to ``sample_augmentation_type`` when ``yy`` must
                be sampled. Presumably a NumPy random generator -- confirm.
            yy: Optional integer array of augmentation types with ``b`` rows.
                Column 0 of row ``i`` selects the permutation for example
                ``i``. If ``None``, an array of shape ``(b, 1)`` is sampled
                via ``sample_augmentation_type``.

        Returns:
            ``(xx, x_emb, yy)``:
            - ``xx`` is permuted and reshaped to ``dataset.shape``.
            - ``x_emb`` is permuted, with shape ``(b, n_emb, n_ctx)``.
            - ``yy`` contains the augmentation types that were used.
        """
        b, n_ctx = xx.shape
        b_emb, n_emb, n_emb_ctx = x_emb.shape
        assert b == b_emb
        assert n_ctx == n_emb_ctx
        if yy is None:
            yy = sample_augmentation_type(n, size=(b, 1), nprng=nprng)
        assert yy.shape[0] == xx.shape[0]
        # Row i is the permutation chosen for example i: shape (b, ctx).
        idx = perm_table[yy[:, 0]]
        # xx_out[i, j] = xx[i, idx[i, j]]
        xx = np.take_along_axis(xx, idx, axis=1).reshape(dataset.shape)
        # The same permutation applies to the last axis of every embedding row:
        # x_emb_out[i, e, j] = x_emb[i, e, idx[i, j]]
        x_emb = np.take_along_axis(x_emb, idx[:, None, :], axis=2)
        return xx, x_emb, yy
    return fn