def color_swap()

in train.py [0:0]


def color_swap(xx, x_emb, nprng=None, yy=None):
    """Apply a per-sample remapping (via ``remap_c``) to a batch and its embeddings.

    Each sample ``xx[i]`` and its embedding ``x_emb[i]`` are passed through
    ``remap_c`` with the same code ``yy[i][0]``. An image and its embedding
    therefore always receive the same remapping.

    Args:
        xx: Batch whose leading axis is the batch size ``b``. It must be
            reshapeable to ``dataset.orig_shape``.
        x_emb: 3-D array unpacked as ``(b, n_emb, n_ctx)``, where ``n_ctx``
            must equal ``prod(dataset.orig_shape[1:])``.
        nprng: Random generator forwarded to ``sample_augmentation_type``
            when ``yy`` has to be sampled.
        yy: Optional remapping codes with one row per sample. Only the first
            entry of each row is used. When ``None``, codes are drawn with
            ``sample_augmentation_type(6, size=(b, 1), nprng=nprng)``.

    Returns:
        A tuple ``(xx, x_emb, yy)``:
        - ``xx``: the remapped batch, reshaped to ``dataset.shape``.
        - ``x_emb``: the remapped embeddings, restored to
          ``(b, n_emb, n_ctx)``.
        - ``yy``: the codes that were applied.

    Raises:
        AssertionError: if the batch sizes or ``n_ctx`` disagree. This only
            happens while assertions are enabled.
    """
    batch = xx.shape[0]
    emb_batch, n_emb, n_ctx = x_emb.shape
    assert batch == emb_batch
    orig_shape = dataset.orig_shape
    assert n_ctx == np.prod(orig_shape[1:])

    if yy is None:
        # NOTE(review): 6 presumably counts the orderings of 3 colour
        # channels (3! = 6) — confirm against sample_augmentation_type/remap_c.
        yy = sample_augmentation_type(6, size=(batch, 1), nprng=nprng)
    assert yy.shape[0] == xx.shape[0]

    # Give every sample its original layout. For the embeddings, the n_emb
    # copies are folded into the axis after the batch (presumably channels),
    # so remap_c sees one (n_emb * orig_shape[1], ...) array per sample.
    images = xx.reshape(orig_shape)
    emb_shape = (batch, n_emb * orig_shape[1]) + tuple(orig_shape[2:])
    embeds = x_emb.reshape(emb_shape)

    # Remap each sample's image and then its embedding, back to back. This
    # keeps the same remap_c call order as a per-sample loop.
    remapped = [
        (remap_c(images[idx], code_row[0]), remap_c(embeds[idx], code_row[0]))
        for idx, code_row in enumerate(yy)
    ]
    image_parts = [pair[0] for pair in remapped]
    embed_parts = [pair[1] for pair in remapped]

    out_xx = np.concatenate(image_parts, axis=0).reshape(dataset.shape)
    out_emb = np.concatenate(embed_parts, axis=0).reshape((emb_batch, n_emb, n_ctx))
    return out_xx, out_emb, yy