def jigsaw()

in train.py [0:0]


def jigsaw(xx, x_emb, nprng=None, yy=None):
    """Run ``remap_jigsaw`` over every example and its embedding planes.

    ``xx`` is viewed as ``dataset.orig_shape`` and ``x_emb`` (unpacked as
    ``(batch, n_emb, n_ctx)``) is viewed as a grid whose trailing dims are
    ``dataset.orig_shape[1:]``.  Each example and its embedding grid are
    passed through ``remap_jigsaw`` with the same permutation id, taken
    from column 0 of the matching row of ``yy``.

    Args:
        xx: batch of examples; reshaped to ``dataset.orig_shape``.
        x_emb: array with three dims ``(batch, n_emb, n_ctx)`` where
            ``n_ctx`` must equal the product of ``dataset.orig_shape[1:]``.
        nprng: forwarded to ``sample_augmentation_type`` when ``yy`` is
            not supplied (presumably a NumPy random state -- confirm).
        yy: optional per-example permutation ids; when ``None`` they are
            drawn with ``sample_augmentation_type(H.jigsaw_num_perms,
            size=(batch, 1), nprng=nprng)``.

    Returns:
        Tuple ``(xx, x_emb, yy)``: the remapped examples reshaped to
        ``dataset.shape``, the remapped embeddings reshaped back to
        ``(batch, n_emb, n_ctx)``, and the permutation ids that were used.
    """
    batch = xx.shape[0]
    emb_batch, n_emb, n_ctx = x_emb.shape
    spatial = dataset.orig_shape[1:]
    rows, cols, chans = spatial
    assert batch == emb_batch
    assert n_ctx == np.prod(spatial)

    images = xx.reshape(dataset.orig_shape)
    if yy is None:
        yy = sample_augmentation_type(
            H.jigsaw_num_perms, size=(batch, 1), nprng=nprng
        )
    assert yy.shape[0] == images.shape[0]

    # Swapping axes 1 and 2 moves the embedding axis inside the row axis, so
    # the n_emb planes can be folded into a single (n_emb * rows, cols, chans)
    # array per example and remapped in one call.  The permutation is its own
    # inverse, so the same tuple undoes the fold afterwards.
    plane_swap = (0, 2, 1, 3, 4)
    emb_grid = np.transpose(
        x_emb.reshape(batch, n_emb, rows, cols, chans), plane_swap
    ).reshape((batch, n_emb * rows, cols, chans))

    shuffled_imgs = []
    shuffled_embs = []
    for idx, order in enumerate(yy):
        perm = order[0]
        shuffled_imgs.append(remap_jigsaw(images[idx], perm))
        shuffled_embs.append(remap_jigsaw(emb_grid[idx], perm))

    xx_out = np.concatenate(shuffled_imgs, axis=0).reshape(dataset.shape)
    emb_out = np.concatenate(shuffled_embs, axis=0).reshape(
        batch, rows, n_emb, cols, chans
    )
    emb_out = np.transpose(emb_out, plane_swap).reshape(
        (emb_batch, n_emb, n_ctx)
    )
    return xx_out, emb_out, yy