in src/modules/transformer_decoder.py [0:0]
def forward(self, features, mask, captions, incremental_state=None):
    """Decode ``captions`` into vocabulary logits, attending over ``features``.

    Args:
        features: encoder feature map or ``None``. Assumed to arrive as
            (batch, channels, time): it is permuted to (batch, time, channels)
            and then transposed to (time, batch, channels) before layer-norm
            — TODO confirm the incoming layout against the encoder.
        mask: encoder padding mask with a singleton dim at index 1 (it is
            ``squeeze(1)``-ed), or ``None``. Inverted here via ``1 - mask``,
            so downstream layers presumably expect 1 = padded position —
            verify against the attention layers' convention.
        captions: token-id tensor, (batch, seq_len).
        incremental_state: optional decoding cache (fairseq-style, judging by
            the ``embed_positions(..., incremental_state=...)`` call). When
            set, only the last time step of ``captions`` is embedded.

    Returns:
        ``(logits, predicted)`` — logits over the vocabulary plus their argmax
        token ids — when ``incremental_state`` is None; otherwise just
        ``logits``. NOTE(review): the return arity differs between the two
        modes; callers must branch on which mode they invoked.
    """
    if features is not None:
        # Reorder encoder features to the decoder's (time, batch, channels)
        # convention, then normalize.
        features = features.permute(0, 2, 1)
        features = features.transpose(0, 1)
        features = self.layer_norm(features)
    if mask is not None:
        # Invert and convert the padding mask. NOTE(review): ``.byte()`` and
        # ``1 - mask`` are the legacy (pre-bool-tensor) PyTorch mask idiom;
        # modern code would use ``~mask.squeeze(1).bool()`` — confirm the
        # torch version / mask dtype before changing.
        mask = (1 - mask.squeeze(1)).byte()
    # embed positions
    if self.embed_positions is not None:
        positions = self.embed_positions(captions, incremental_state=incremental_state)
    if incremental_state is not None:
        # Incremental decoding: past steps are cached, so embed only the
        # newest token (and its position).
        if self.embed_positions is not None:
            positions = positions[:, -1:]
        captions = captions[:, -1:]
    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(captions)
    if self.embed_positions is not None:
        x += positions
    x = F.dropout(x, p=self.dropout, training=self.training)
    # B x T x C -> T x B x C
    x = x.transpose(0, 1)
    # decoder layers (each attends over ``features`` with ``mask`` and may
    # read/update the incremental cache)
    for layer in self.layers:
        x = layer(x, features, mask, incremental_state)
    # T x B x C -> B x T x C
    x = x.transpose(0, 1)
    # Project hidden states to vocabulary logits; ``predicted`` is the
    # greedy argmax over the vocabulary dimension.
    x = self.linear(x)
    _, predicted = x.max(dim=-1)
    if incremental_state is None:
        return x, predicted
    else:
        return x