in mask_former/modeling/matcher.py [0:0]
def memory_efficient_forward(self, outputs, targets):
"""More memory-friendly matching"""
bs, num_queries = outputs["pred_logits"].shape[:2]
# Work out the mask padding size
masks = [v["masks"] for v in targets]
h_max = max([m.shape[1] for m in masks])
w_max = max([m.shape[2] for m in masks])
indices = []
# Iterate through batch size
for b in range(bs):
out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
tgt_ids = targets[b]["labels"]
# gt masks are already padded when preparing target
tgt_mask = targets[b]["masks"].to(out_mask)
        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it with 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        cost_class = -out_prob[:, tgt_ids]
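        # Example (hypothetical sizes): with 100 queries and 5 ground-truth instances in this
        # image, out_prob[:, tgt_ids] gathers each target's class probability for every query,
        # so cost_class is a [100, 5] matrix; negating it makes confident queries cheap to match.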
        # Downsample gt masks to save memory
        tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest")

        # Flatten spatial dimensions
        out_mask = out_mask.flatten(1)  # [num_queries, H_pred*W_pred]
        tgt_mask = tgt_mask[:, 0].flatten(1)  # [num_targets, H_pred*W_pred]
        # Compute the focal loss between masks
        cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)

        # Compute the dice loss between masks
        cost_dice = batch_dice_loss(out_mask, tgt_mask)
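        # Both helpers return pairwise costs of shape [num_queries, num_targets]: every
        # predicted mask is scored against every ground-truth mask (a sketch of these
        # helpers follows the method).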
# Final cost matrix
C = (
self.cost_mask * cost_mask
+ self.cost_class * cost_class
+ self.cost_dice * cost_dice
)
C = C.reshape(num_queries, -1).cpu()
indices.append(linear_sum_assignment(C))
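        # scipy's linear_sum_assignment solves the Hungarian matching on C and returns
        # (query_indices, target_indices) minimizing the total cost for this image.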

    return [
        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
        for i, j in indices
    ]
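
The two pairwise cost helpers used above, batch_sigmoid_focal_loss and batch_dice_loss, are defined elsewhere in matcher.py and are not shown in this excerpt. The sketch below is a minimal, self-contained approximation of what they compute, assuming flattened logits of shape [num_queries, H*W] and binary targets of shape [num_targets, H*W]; the alpha/gamma defaults and the exact normalization here are assumptions, not a verbatim copy of the repository code.

import torch
import torch.nn.functional as F  # already imported at the top of matcher.py


def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Pairwise dice cost: inputs are [num_queries, H*W] mask logits,
    targets are [num_targets, H*W] binary masks; returns [num_queries, num_targets]."""
    inputs = inputs.sigmoid()
    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
    # 1 - dice coefficient, smoothed with +1 to avoid division by zero
    return 1 - (numerator + 1) / (denominator + 1)


def batch_sigmoid_focal_loss(
    inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0
) -> torch.Tensor:
    """Pairwise focal cost with the same shapes and return as batch_dice_loss."""
    hw = inputs.shape[1]
    prob = inputs.sigmoid()
    # Focal-weighted BCE against an all-ones target (cost of predicting foreground)
    # and an all-zeros target (cost of predicting background), per pixel.
    focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits(
        inputs, torch.ones_like(inputs), reduction="none"
    )
    focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits(
        inputs, torch.zeros_like(inputs), reduction="none"
    )
    if alpha >= 0:
        focal_pos = focal_pos * alpha
        focal_neg = focal_neg * (1 - alpha)
    # Combine per (query, target) pair: positive term where the gt mask is 1,
    # negative term where it is 0; average over pixels.
    loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum(
        "nc,mc->nm", focal_neg, (1 - targets)
    )
    return loss / hw

In this sketch the einsum contracts over pixels directly, so a [num_queries, num_targets, H*W] tensor is never materialized when forming the pairwise cost matrices.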