def pad_packed_images()

in seamseg/utils/sequence.py


def pad_packed_images(packed_images, pad_value=0., snap_size_to=None):
    """Assemble a padded tensor for a `PackedSequence` of images with different spatial sizes

    This method allows any standard convnet to operate on a `PackedSequence` of images as a batch

    Parameters
    ----------
    packed_images : PackedSequence
        A PackedSequence containing N tensors with different spatial sizes H_i, W_i. The tensors can be either 2D or 3D.
        If they are 3D, they must all have the same number of channels C.
    pad_value : float or int
        Value used to fill the padded areas
    snap_size_to : int or None
        If not None, choose the spatial sizes of the padded tensor to be multiples of this value

    Returns
    -------
    padded_images : torch.Tensor
        A tensor with shape N x C x H x W or N x H x W, where `H = max_i H_i` and `W = max_i W_i`, containing the
        images of the sequence aligned to the top-left corner and padded with `pad_value`
    sizes : list of tuple of int
        A list with the original spatial sizes of the input images
    """
    if packed_images.all_none:
        raise ValueError("at least one image in packed_images should be non-None")

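    # Use the first non-None image as the reference for dimensionality,
    # channel count, dtype and device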
    reference_img = next(img for img in packed_images if img is not None)
    max_size = reference_img.shape[-2:]
    ndims = len(reference_img.shape)
    chn = reference_img.shape[0] if ndims == 3 else 0

    # Check the shapes and find maximum spatial size
    for img in packed_images:
        if img is not None:
            if len(img.shape) != 3 and len(img.shape) != 2:
                raise ValueError("The input sequence must contain 2D or 3D tensors")
            if len(img.shape) != ndims:
                raise ValueError("All tensors in the input sequence must have the same number of dimensions")
            if ndims == 3 and img.shape[0] != chn:
                raise ValueError("3D tensors must all have the same number of channels")
            max_size = [max(s1, s2) for s1, s2 in zip(max_size, img.shape[-2:])]

    # Optional size snapping
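    # Each spatial dimension is rounded up to the nearest multiple of snap_size_to,
    # e.g. with snap_size_to=32 a maximum size of (120, 150) becomes (128, 160)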
    if snap_size_to is not None:
        max_size = [(s + snap_size_to - 1) // snap_size_to * snap_size_to for s in max_size]

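    # Allocate the padded batch with the dtype and device of the reference image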
    if ndims == 3:
        padded_images = reference_img.new_full([len(packed_images), chn] + max_size, pad_value)
    else:
        padded_images = reference_img.new_full([len(packed_images)] + max_size, pad_value)

    sizes = []
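    # Copy each image into the top-left corner of its slot; None entries are left
    # entirely at pad_value and recorded with size (0, 0)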
    for i, tensor in enumerate(packed_images):
        if tensor is not None:
            if ndims == 3:
                padded_images[i, :, :tensor.shape[1], :tensor.shape[2]] = tensor
                sizes.append(tensor.shape[1:])
            else:
                padded_images[i, :tensor.shape[0], :tensor.shape[1]] = tensor
                sizes.append(tensor.shape)
        else:
            sizes.append((0, 0))

    return padded_images, sizes
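
For reference, a minimal usage sketch. It assumes that `PackedSequence` from the same module can be constructed from a plain list of tensors; only `pad_packed_images` itself is shown above, so the constructor call is an assumption:

import torch
from seamseg.utils.sequence import PackedSequence, pad_packed_images

# Two 3-channel images with different spatial sizes
img_a = torch.rand(3, 100, 150)
img_b = torch.rand(3, 120, 140)

# Assumption: PackedSequence can be built from a list of tensors
packed = PackedSequence([img_a, img_b])

padded, sizes = pad_packed_images(packed, pad_value=0., snap_size_to=32)
print(padded.shape)  # torch.Size([2, 3, 128, 160]): max size (120, 150) snapped up to multiples of 32
print(sizes)         # [torch.Size([100, 150]), torch.Size([120, 140])]

With `snap_size_to=None` the output would instead be padded to the raw maximum size, i.e. shape (2, 3, 120, 150).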