def rotate()

in train.py [0:0]


def rotate(xx, x_emb, nprng=None, yy=None):
    """Rotate every sample of a batch, and its embedding maps, by quarter turns.

    Sample ``i`` of both ``xx`` and ``x_emb`` is passed to ``np.rot90`` with
    ``k=yy[i, 0]``. Each image is rotated in the plane of its first two
    per-sample axes. The embedding maps are rotated in the same plane and
    direction, so they stay aligned with their image.

    Args:
        xx: batch whose leading dim is the batch size. It must be reshapeable
            to ``dataset.orig_shape``; presumably this is a flat
            ``(batch, n_ctx)`` array, but confirm against callers.
        x_emb: array of shape ``(batch, n_emb, n_ctx)``, where ``n_ctx`` must
            equal ``prod(dataset.orig_shape[1:])``.
        nprng: forwarded to ``sample_augmentation_type`` when ``yy`` is None.
        yy: optional labels indexed as ``yy[i, 0]``, giving the number of
            quarter turns for sample ``i``. When None, they are drawn with
            ``sample_augmentation_type(4, size=(batch, 1), nprng=nprng)``.

    Returns:
        A tuple ``(xx, x_emb, yy)``:
        - ``xx``: the rotated images, reshaped to ``dataset.shape``.
        - ``x_emb``: the rotated embeddings, reshaped back to
          ``(batch, n_emb, n_ctx)``.
        - ``yy``: the labels that were used.

    NOTE(review): this relies on a module-level ``dataset`` exposing
    ``orig_shape`` and ``shape``. An odd ``k`` swaps the two rotated axes, so
    mixing odd and even ``k`` only stacks cleanly when those axes have equal
    length (presumably square images); confirm.
    """
    batch = xx.shape[0]
    emb_batch, emb_channels, ctx_len = x_emb.shape
    assert batch == emb_batch
    assert ctx_len == np.prod(dataset.orig_shape[1:])

    if yy is None:
        yy = sample_augmentation_type(4, size=(batch, 1), nprng=nprng)
    assert yy.shape[0] == batch

    def _spin(stack, plane):
        # Rotate stack[i] by yy[i, 0] quarter turns in `plane`, then restack.
        turned = []
        for idx, item in enumerate(stack):
            turned.append(np.rot90(item, k=yy[idx, 0], axes=plane))
        return np.asarray(turned)

    # Images: un-flatten each sample so rot90 sees its spatial layout.
    rotated_x = _spin(xx.reshape(dataset.orig_shape), (1, 0)).reshape(dataset.shape)

    # Embeddings: the extra leading n_emb axis shifts the spatial axes by one,
    # hence (2, 1) here, which matches (1, 0) on the image samples.
    emb_grid = x_emb.reshape((emb_batch, emb_channels, *dataset.orig_shape[1:]))
    rotated_emb = _spin(emb_grid, (2, 1)).reshape((emb_batch, emb_channels, ctx_len))

    return rotated_x, rotated_emb, yy