in mask2former/maskformer_model.py [0:0]
def forward(self, batched_inputs):
"""
Args:
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
Each item in the list contains the inputs for one image.
For now, each item in the list is a dict that contains:
* "image": Tensor, image in (C, H, W) format.
* "instances": per-region ground truth
* Other information that's included in the original dicts, such as:
"height", "width" (int): the output resolution of the model (may be different
from input resolution), used in inference.
Returns:
list[dict]:
each dict has the results for one image. The dict contains the following keys:
* "sem_seg":
A Tensor that represents the
per-pixel segmentation predicted by the head.
The prediction has shape KxHxW that represents the logits of
each class for each pixel.
* "panoptic_seg":
A tuple that represent panoptic output
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
segments_info (list[dict]): Describe each segment in `panoptic_seg`.
Each dict contains keys "id", "category_id", "isthing".
"""
images = [x["image"].to(self.device) for x in batched_inputs]
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
images = ImageList.from_tensors(images, self.size_divisibility)
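# The backbone produces multi-scale features; the segmentation head turns them into
# per-query classification logits ("pred_logits") and mask predictions ("pred_masks").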
features = self.backbone(images.tensor)
outputs = self.sem_seg_head(features)
if self.training:
# mask classification target
if "instances" in batched_inputs[0]:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
targets = self.prepare_targets(gt_instances, images)
else:
targets = None
# bipartite matching-based loss
losses = self.criterion(outputs, targets)
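# weight each loss term by the criterion's weight_dict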
for k in list(losses.keys()):
if k in self.criterion.weight_dict:
losses[k] *= self.criterion.weight_dict[k]
else:
# remove this loss if not specified in `weight_dict`
losses.pop(k)
return losses
else:
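# Inference path: take per-query class logits and mask predictions from the head.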
mask_cls_results = outputs["pred_logits"]
mask_pred_results = outputs["pred_masks"]
# upsample masks to the padded input resolution
mask_pred_results = F.interpolate(
mask_pred_results,
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
mode="bilinear",
align_corners=False,
)
del outputs
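# Post-process each image independently: undo padding, resize to the requested
# output resolution, and run whichever inference modes are enabled.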
processed_results = []
for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
):
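# image_size is the input size before padding; "height"/"width" (if given) are the
# desired output resolution for this image.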
height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1])
processed_results.append({})
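# If configured, crop away padding and resize masks to (height, width) before running
# inference; otherwise semantic results are resized after inference below.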
if self.sem_seg_postprocess_before_inference:
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
mask_pred_result, image_size, height, width
)
mask_cls_result = mask_cls_result.to(mask_pred_result)
# semantic segmentation inference
if self.semantic_on:
r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
if not self.sem_seg_postprocess_before_inference:
r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
processed_results[-1]["sem_seg"] = r
# panoptic segmentation inference
if self.panoptic_on:
panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
processed_results[-1]["panoptic_seg"] = panoptic_r
# instance segmentation inference
if self.instance_on:
instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result)
processed_results[-1]["instances"] = instance_r
return processed_results
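# Minimal usage sketch (illustrative; `model` and the input tensor are assumptions,
# not part of this file):
#   model.eval()
#   with torch.no_grad():
#       results = model([{"image": image_chw_tensor, "height": 480, "width": 640}])
#   results[0]["sem_seg"], results[0].get("panoptic_seg"), results[0].get("instances")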