in train.py [0:0]
def transpose(xx, x_emb, nprng=None, yy=None):
    """Per example, optionally swap the first two non-batch axes of the data.

    The inputs are reshaped to ``dataset.orig_shape``. For every example whose
    augmentation label ``yy[i, 0]`` equals 1, the leading two axes of that
    example are swapped. The same swap is applied to every embedding of that
    example in ``x_emb``. The results are then flattened back.

    Args:
        xx: Batch of inputs. Its total size must reshape into
            ``dataset.orig_shape``. Presumably the batch entry of
            ``orig_shape`` is -1 — TODO confirm against the dataset
            definition.
        x_emb: Array of shape ``(b, n_emb, n_ctx)``. ``n_ctx`` must equal the
            product of ``dataset.orig_shape[1:]``.
        nprng: Random generator forwarded to ``sample_augmentation_type``.
            Only used when ``yy`` is None.
        yy: Optional ``(b, 1)`` array of augmentation labels. If omitted,
            labels are drawn from ``sample_augmentation_type`` with 2 types.

    Returns:
        Tuple ``(xx, x_emb, yy)``. ``xx`` is reshaped to ``dataset.shape``,
        ``x_emb`` is reshaped to ``(b, n_emb, n_ctx)``, and ``yy`` holds the
        labels that were used.
    """
    batch = xx.shape[0]
    emb_batch, emb_count, ctx_len = x_emb.shape
    assert batch == emb_batch
    per_example_shape = dataset.orig_shape[1:]
    assert ctx_len == np.prod(per_example_shape)

    if yy is None:
        # Draw one label per example from 2 augmentation types; label 1 triggers the swap.
        yy = sample_augmentation_type(2, size=(batch, 1), nprng=nprng)
    assert yy.shape[0] == xx.shape[0]

    def is_flipped(idx):
        # Looked up lazily per example so indexing matches a per-row scan of yy.
        return yy[idx, 0] == 1

    # Each image has rank 3 here; [1, 0, 2] swaps its first two axes.
    images = xx.reshape(dataset.orig_shape)
    swapped_images = np.asarray([
        np.transpose(img, [1, 0, 2]) if is_flipped(idx) else img
        for idx, img in enumerate(images)
    ])
    xx = swapped_images.reshape(dataset.shape)

    # Each example's embeddings are (n_emb, *per_example_shape); [0, 2, 1, 3]
    # applies the same swap while leaving the embedding axis in place.
    embeddings = x_emb.reshape((emb_batch, emb_count, *per_example_shape))
    swapped_embeddings = np.asarray([
        np.transpose(emb, [0, 2, 1, 3]) if is_flipped(idx) else emb
        for idx, emb in enumerate(embeddings)
    ])
    x_emb = swapped_embeddings.reshape((emb_batch, emb_count, ctx_len))

    return xx, x_emb, yy