in mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py [0:0]
def forward(self, x, mask_features, mask=None):
    """Decode the learnable queries against a pyramid of feature maps.

    Each decoder layer runs cross-attention, then self-attention, then an FFN.
    Layers attend to the feature levels in round-robin order. The attention
    mask for each layer comes from the prediction heads evaluated on the
    previous layer's output.

    Args:
        x: Sequence of multi-scale feature maps. Its length must equal
            ``self.num_feature_levels``. Each entry is laid out N x C x H x W,
            and its trailing two dims are recorded as that level's size.
        mask_features: Forwarded unchanged to ``self.forward_prediction_heads``.
        mask: Accepted for interface compatibility and discarded.

    Returns:
        A dict with three entries:
        ``'pred_logits'`` and ``'pred_masks'``: predictions after the final
        decoder layer.
        ``'aux_outputs'``: the result of ``self._set_aux_loss`` over every
        prediction (initial queries plus one per layer). Class predictions
        are passed as ``None`` when ``self.mask_classification`` is falsy.
    """
    levels = self.num_feature_levels
    assert len(x) == levels
    # The padding mask is intentionally dropped; it does not affect performance.
    del mask

    memories, positions, spatial_sizes = [], [], []
    for lvl in range(levels):
        feat = x[lvl]
        spatial_sizes.append(feat.shape[-2:])
        # Positional encoding and projected features, each reshaped
        # N x C x H x W -> N x C x HW -> HW x N x C.
        positions.append(self.pe_layer(feat, None).flatten(2).permute(2, 0, 1))
        projected = self.input_proj[lvl](feat).flatten(2)
        level_bias = self.level_embed.weight[lvl][None, :, None]
        memories.append((projected + level_bias).permute(2, 0, 1))

    batch_size = memories[0].shape[1]

    # Q x N x C: tile the learned query embeddings/features across the batch.
    query_pos = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
    queries = self.query_feat.weight.unsqueeze(1).repeat(1, batch_size, 1)

    cls_history, mask_history = [], []

    # Predictions straight from the learnable queries, at the first level's size.
    cls_pred, mask_pred, attn_mask = self.forward_prediction_heads(
        queries, mask_features, attn_mask_target_size=spatial_sizes[0]
    )
    cls_history.append(cls_pred)
    mask_history.append(mask_pred)

    for layer in range(self.num_layers):
        lvl = layer % levels

        # Un-mask any query row that is masked at every key position.
        # NOTE(review): presumably True means "blocked" (as in
        # torch.nn.MultiheadAttention), so an all-True row would leave the
        # query nothing to attend to; confirm against the attention layer.
        fully_masked = attn_mask.sum(-1) == attn_mask.shape[-1]
        attn_mask[fully_masked] = False

        queries = self.transformer_cross_attention_layers[layer](
            queries, memories[lvl],
            memory_mask=attn_mask,
            memory_key_padding_mask=None,  # padded pixels are deliberately left unmasked
            pos=positions[lvl], query_pos=query_pos
        )
        queries = self.transformer_self_attention_layers[layer](
            queries, tgt_mask=None,
            tgt_key_padding_mask=None,
            query_pos=query_pos
        )
        queries = self.transformer_ffn_layers[layer](queries)

        # Predict at the resolution of the level the next layer will attend to.
        next_size = spatial_sizes[(layer + 1) % levels]
        cls_pred, mask_pred, attn_mask = self.forward_prediction_heads(
            queries, mask_features, attn_mask_target_size=next_size
        )
        cls_history.append(cls_pred)
        mask_history.append(mask_pred)

    assert len(cls_history) == self.num_layers + 1

    aux_cls = cls_history if self.mask_classification else None
    return {
        'pred_logits': cls_history[-1],
        'pred_masks': mask_history[-1],
        'aux_outputs': self._set_aux_loss(aux_cls, mask_history),
    }