in mask2former_video/modeling/transformer_decoder/video_mask2former_transformer_decoder.py [0:0]
def forward(self, x, mask_features, mask = None):
    """Run the video transformer decoder over multi-scale features.

    Args:
        x: indexable collection of multi-scale feature tensors; it must
            hold exactly ``self.num_feature_levels`` entries. Each entry is
            reshaped to ``(bs, t, -1, H, W)`` below, so it is presumably
            laid out as ``(bs * t, C, H, W)`` -- TODO confirm with caller.
        mask_features: 4-D tensor whose leading dimension stacks all frames
            of all clips; it is reshaped here to ``(bs, t, C, H, W)``.
        mask: ignored; the reference is dropped right away.

    Returns:
        dict with keys ``'pred_logits'`` and ``'pred_masks'`` (the
        predictions made after the last decoder layer) and
        ``'aux_outputs'`` (whatever ``self._set_aux_loss`` builds from the
        per-layer predictions).
    """
    total_frames, mask_channels, mask_h, mask_w = mask_features.shape
    # While training, the leading dim is split into clips of
    # self.num_frames frames; otherwise everything is treated as one clip.
    if self.training:
        bs = total_frames // self.num_frames
    else:
        bs = 1
    t = total_frames // bs
    mask_features = mask_features.view(bs, t, mask_channels, mask_h, mask_w)

    # x is a list of multi-scale feature
    assert len(x) == self.num_feature_levels

    src = []
    pos = []
    size_list = []

    # disable mask, it does not affect performance
    del mask

    for level in range(self.num_feature_levels):
        feature = x[level]
        spatial_size = feature.shape[-2:]
        size_list.append(spatial_size)

        # Positional encoding is computed on the per-clip (5-D) view.
        clip_view = feature.view(bs, t, -1, spatial_size[0], spatial_size[1])
        level_pos = self.pe_layer(clip_view, None).flatten(3)

        # Projected features plus a learned per-level embedding that is
        # broadcast over the batch and spatial dimensions.
        level_bias = self.level_embed.weight[level][None, :, None]
        level_src = self.input_proj[level](feature).flatten(2) + level_bias

        # NTxCxHW => NxTxCxHW => (TxHW)xNxC
        _, c, hw = level_src.shape
        pos.append(level_pos.view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1))
        src.append(level_src.view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1))

    # QxNxC: the learned query tensors are repeated once per clip.
    query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
    output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

    predictions_class = []
    predictions_mask = []

    # prediction heads on learnable query features
    outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
        output, mask_features, attn_mask_target_size=size_list[0]
    )
    predictions_class.append(outputs_class)
    predictions_mask.append(outputs_mask)

    for layer_idx in range(self.num_layers):
        # Feature levels are visited round-robin across decoder layers.
        level_index = layer_idx % self.num_feature_levels

        # Rows of attn_mask in which every entry along the last dim is set
        # are cleared in place.
        # NOTE(review): presumably this keeps a query from ending up with
        # no key it may attend to -- confirm against the attention layer's
        # mask convention.
        row_is_full = attn_mask.sum(-1) == attn_mask.shape[-1]
        full_row_index = torch.where(row_is_full)
        attn_mask[full_row_index] = False

        # attention: cross-attention first
        cross_attention = self.transformer_cross_attention_layers[layer_idx]
        output = cross_attention(
            output,
            src[level_index],
            memory_mask=attn_mask,
            # here we do not apply masking on padded region
            memory_key_padding_mask=None,
            pos=pos[level_index],
            query_pos=query_embed,
        )

        self_attention = self.transformer_self_attention_layers[layer_idx]
        output = self_attention(
            output,
            tgt_mask=None,
            tgt_key_padding_mask=None,
            query_pos=query_embed,
        )

        # FFN
        feed_forward = self.transformer_ffn_layers[layer_idx]
        output = feed_forward(output)

        # The attention mask for the next layer is sized for the level
        # that layer will attend to.
        next_level = (layer_idx + 1) % self.num_feature_levels
        outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[next_level]
        )
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)

    # One prediction before the first layer plus one after every layer.
    assert len(predictions_class) == self.num_layers + 1

    # Class predictions are handed to the aux-loss helper only when
    # self.mask_classification is truthy.
    aux_class_predictions = predictions_class if self.mask_classification else None
    return {
        'pred_logits': predictions_class[-1],
        'pred_masks': predictions_mask[-1],
        'aux_outputs': self._set_aux_loss(aux_class_predictions, predictions_mask),
    }