def transform()

in datasets/transformations.py [0:0]


def transform(image, params):
    """
    Applies transformations on a single image based on params.
    Order of transformation is: rotate, translate, scale

    Args:
        image (np.array or torch.tensor): of shape [n_channels, n_pixels, n_pixels].
            All singleton dimensions (e.g. a single channel) are squeezed away
            before the transformations are applied.
        params (Params): contains parameters for rotations, scaling etc. The
            attributes read here are ``angle`` (degrees, passed to
            ``skimage.transform.rotate``), ``shift_x`` / ``shift_y`` (pixels,
            truncated toward zero with ``int``) and ``scale`` (passed to
            ``rescale``).

    Returns: image with transformations applied, passed through ``to_torch``
        together with the original ``image``.

    Raises:
        AssertionError: if ``image`` is not 3-dimensional (note: ``assert`` is
            stripped when Python runs with ``-O``).
    """
    assert (
        image.ndim == 3
    ), f"image must be of shape [n_channels, n_pixels, n_pixels] not {image.shape}"

    # Drops every size-1 axis, e.g. [1, H, W] -> [H, W].
    image_transformed = image.squeeze()

    # Rotate. Any whole multiple of 360 degrees is a no-op, so skip it.
    if params.angle % 360.0 != 0.0:
        image_transformed = skimage.transform.rotate(
            image_transformed,
            params.angle,
            # cval is the fill value for the uncovered corners.
            cval=image_transformed.min(),
            # Keep the input's intensity range. Without this, integer images are
            # converted to floats in [0, 1] (only when a rotation happens), which
            # makes the output range depend on the angle and leaves `cval`
            # (computed in the original range) inconsistent with the result.
            preserve_range=True,
        )

    # Translate
    # np.roll wraps around: the portion pushed past one edge reappears on the
    # opposite side. Positive shift_x moves content toward higher column
    # indices; positive shift_y moves content toward row 0 (the sign is negated).
    if params.shift_x != 0.0:
        image_transformed = np.roll(image_transformed, int(params.shift_x), axis=1)
    if params.shift_y != 0.0:
        image_transformed = np.roll(image_transformed, -int(params.shift_y), axis=0)

    # Scale
    if params.scale != 1.0:
        image_transformed = rescale(image_transformed, params.scale)

    # Convert the result using the original input as reference (presumably to
    # match its array/tensor type — see `to_torch`).
    image_transformed = to_torch(image, image_transformed)
    return image_transformed