in train.py [0:0]
def color_swap(xx, x_emb, nprng=None, yy=None):
    """Remap each sample of ``xx`` and ``x_emb`` with a per-sample ``remap_c`` choice.

    Args:
        xx: batch of inputs; axis 0 is the batch. It is reshaped to
            ``dataset.orig_shape`` before remapping.
        x_emb: 3-D array unpacked as ``(batch, n_emb, n_ctx)``. ``n_ctx`` must
            equal ``prod(dataset.orig_shape[1:])``.
        nprng: forwarded to ``sample_augmentation_type`` when ``yy`` is sampled.
        yy: optional per-sample augmentation ids. If ``None``, they are drawn via
            ``sample_augmentation_type(6, size=(batch, 1), nprng=nprng)``.
            ``yy[i][0]`` is the choice passed to ``remap_c`` for sample ``i``.

    Returns:
        ``(xx_out, x_emb_out, yy)``. ``xx_out`` is reshaped to ``dataset.shape``.
        ``x_emb_out`` is reshaped back to ``(batch, n_emb, n_ctx)``.

    Raises:
        AssertionError: on batch-size or context-size mismatches.
    """
    batch = xx.shape[0]
    emb_batch, n_emb, n_ctx = x_emb.shape
    assert batch == emb_batch
    assert n_ctx == np.prod(dataset.orig_shape[1:])

    if yy is None:
        # NOTE(review): 6 is presumably the 3! orderings of three color
        # channels, given the function name -- confirm in remap_c.
        yy = sample_augmentation_type(6, size=(batch, 1), nprng=nprng)
    assert yy.shape[0] == xx.shape[0]

    item_shape = dataset.orig_shape[1:]
    images = xx.reshape(dataset.orig_shape)
    # Treat each of the n_emb embedding rows as one more item stacked along
    # the first per-item axis, so remap_c can operate on it like an input.
    emb_images = x_emb.reshape((batch, n_emb * item_shape[0], *item_shape[1:]))

    swapped_images = []
    swapped_embs = []
    for idx, aug in enumerate(yy):
        choice = aug[0]
        # Input and embedding for the same sample get the same remapping.
        swapped_images.append(remap_c(images[idx], choice))
        swapped_embs.append(remap_c(emb_images[idx], choice))

    out_x = np.concatenate(swapped_images, axis=0).reshape(dataset.shape)
    out_emb = np.concatenate(swapped_embs, axis=0).reshape((emb_batch, n_emb, n_ctx))
    return out_x, out_emb, yy