def create_gradient_mask()

in src/image_gen_aux/utils/tiling_utils.py [0:0]


def create_gradient_mask(shape: Tuple, feather: int, device="cpu") -> torch.Tensor:
    """
    Create a gradient mask for smooth blending of tiles.

    Each of the four borders is attenuated with a linear ramp. The row (or
    column) at 0-based distance ``i`` from a border is scaled by
    ``(i + 1) / feather``. The outermost row/column therefore gets
    ``1 / feather``, and the ramp reaches 1 at distance ``feather - 1``.
    Row and column weights multiply, so corners are attenuated by both ramps.

    Args:
        shape (tuple): Shape of the mask (batch, channels, height, width)
        feather (int): Width of the feathered edge. Values <= 0 leave the
            mask as all ones.
        device: Device on which the returned mask is allocated
            (default ``"cpu"``).

    Returns:
        torch.Tensor: Contiguous gradient mask of the given ``shape`` in the
        default floating dtype. It is identical across the batch and channel
        dimensions.

    Raises:
        IndexError: If ``feather`` is larger than the height or the width.

    Note:
        When ``2 * feather`` exceeds the height (or width), the top and bottom
        (or left and right) ramps overlap. The overlapping rows (columns) are
        then scaled by both factors.
    """
    _, _, h, w = shape

    def _edge_weights(size: int) -> torch.Tensor:
        # 1-D weights along one spatial axis. They are built on the host, so
        # the per-step scalar updates do not each launch a device kernel.
        # Sequential in-place updates keep the original behaviour: overlapping
        # ramps multiply, and an out-of-range step raises IndexError.
        weights = torch.ones(size)
        for feather_step in range(feather):
            factor = (feather_step + 1) / feather
            weights[feather_step] *= factor
            weights[size - 1 - feather_step] *= factor
        return weights

    row_weights = _edge_weights(h)
    col_weights = _edge_weights(w)

    # The outer product gives the (h, w) weight map. It is transferred to the
    # target device once.
    weight_map = (row_weights[:, None] * col_weights[None, :]).to(device)

    # Multiplying into a full ones tensor broadcasts the map over batch and
    # channels. The result is a contiguous, non-aliased tensor that callers
    # may modify in place.
    mask = torch.ones(shape, device=device)
    mask *= weight_map
    return mask