in maskrcnn_benchmark/structures/segmentation_mask.py [0:0]
def __init__(self, masks, size):
    """
    Build a container of binary instance masks for one image.

    Arguments:
        masks: Either torch.tensor of [num_instances, H, W]
            or list of torch.tensors of [H, W] with num_instances elems,
            or RLE (Run Length Encoding) - interpreted as list of dicts,
            or BinaryMaskList.
            A single [H, W] tensor is accepted and treated as one instance.
        size: absolute image size, width first, i.e. (W, H).

    After initialization, a hard copy will be made, to leave the
    initializing source data intact.

    RLE masks whose encoded size differs from ``size`` are resized with
    bilinear interpolation to (H, W).

    Raises:
        RuntimeError: if ``masks`` (or, for a list/tuple, its first
            element) has a type that cannot be interpreted.
        AssertionError (via ``assert``): if ``size`` is not a 2-element
            list/tuple, RLE entries disagree on their size, or the final
            tensor is not [N, H, W] with H == size[1] and W == size[0].
    """
    assert isinstance(size, (list, tuple))
    assert len(size) == 2
    width, height = size

    if isinstance(masks, torch.Tensor):
        # The raw data representation is passed as argument
        masks = masks.clone()
    elif isinstance(masks, (list, tuple)):
        if len(masks) == 0:
            masks = torch.empty([0, height, width])  # num_instances = 0!
        elif isinstance(masks[0], torch.Tensor):
            masks = torch.stack(masks, dim=0).clone()
        elif isinstance(masks[0], dict) and "counts" in masks[0]:
            # Counts given as a plain list (uncompressed RLE) are passed
            # through frPyObjects before decoding.
            if isinstance(masks[0]["counts"], (list, tuple)):
                masks = mask_utils.frPyObjects(masks, height, width)
            # RLE interpretation. Validate that every instance shares one
            # size *before* decoding them together into a single array.
            rle_sizes = [tuple(inst["size"]) for inst in masks]
            assert rle_sizes.count(rle_sizes[0]) == len(rle_sizes), (
                "All the sizes must be the same size: %s" % rle_sizes
            )
            masks = mask_utils.decode(masks)  # [h, w, n]
            masks = torch.tensor(masks).permute(2, 0, 1)  # [n, h, w]
            # in RLE, height come first in "size"
            rle_height, rle_width = rle_sizes[0]
            assert masks.shape[1] == rle_height
            assert masks.shape[2] == rle_width
            if width != rle_width or height != rle_height:
                # NOTE(review): the bilinear result is cast back to the
                # decoded dtype; if that dtype is integral, fractional
                # border values are truncated — confirm this is intended.
                masks = interpolate(
                    input=masks[None].float(),
                    size=(height, width),
                    mode="bilinear",
                    align_corners=False,
                )[0].type_as(masks)
        else:
            # Report the offending element's type, not the container's.
            raise RuntimeError(
                "Type of `masks[0]` could not be interpreted: %s"
                % type(masks[0])
            )
    elif isinstance(masks, BinaryMaskList):
        # just hard copy the BinaryMaskList instance's underlying data
        masks = masks.masks.clone()
    else:
        raise RuntimeError(
            "Type of `masks` argument could not be interpreted:%s"
            % type(masks)
        )

    if len(masks.shape) == 2:
        # if only a single instance mask is passed
        masks = masks[None]
    assert len(masks.shape) == 3
    assert masks.shape[1] == size[1], "%s != %s" % (masks.shape[1], size[1])
    assert masks.shape[2] == size[0], "%s != %s" % (masks.shape[2], size[0])
    self.masks = masks
    self.size = tuple(size)