in avhubert/hubert.py [0:0]
def extract_finetune(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None):
    """Encode audio and/or video input with the encoder for fine-tuning.

    Each modality that is present is passed through
    ``self.forward_features``; a missing modality is replaced by a zero
    tensor with ``self.encoder_embed_dim`` channels. The two feature
    streams are fused according to ``self.modality_fuse``, passed through
    ``self.layer_norm``, optionally ``self.post_extract_proj``, then
    ``self.dropout_input`` and finally ``self.encoder``.

    Args:
        source: Mapping with the keys ``'audio'`` and ``'video'``. Either
            value may be ``None``, but not both.
        padding_mask: Optional padding mask. When given, it is converted
            with ``self.forward_padding_mask`` before reaching the encoder.
        mask: When true and ``self.masking_type == 'input'``, every
            modality that is present is masked with
            ``self.apply_input_mask`` before feature extraction.
        ret_conv: Unused; kept so existing callers keep working.
        output_layer: When not ``None``, the encoder is called with
            ``layer=output_layer - 1``; otherwise with ``layer=None``.

    Returns:
        A tuple ``(x, padding_mask)``: the first element of the encoder's
        return value and the (possibly converted) padding mask.

    Raises:
        ValueError: If both modalities are ``None``, or if
            ``self.modality_fuse`` is neither ``'concat'`` nor ``'add'``.
    """
    src_audio, src_video = source['audio'], source['video']
    if src_audio is None and src_video is None:
        raise ValueError(
            "extract_finetune needs at least one of source['audio'] or source['video']"
        )

    if mask and self.masking_type == 'input':
        # The returned mask indices are not used in fine-tuning, so they are
        # discarded. Only streams that are present are masked.
        # NOTE(review): presumably apply_input_mask cannot take None; confirm.
        # The video-then-audio order of the original calls is preserved.
        if src_video is not None:
            src_video, _ = self.apply_input_mask(src_video, padding_mask, target_list=None)
        if src_audio is not None:
            src_audio, _ = self.apply_input_mask(src_audio, padding_mask, target_list=None)

    if src_video is None:
        # Audio only: stand in for the video stream with zeros.
        features_audio = self.forward_features(src_audio, modality='audio')  # features: [B, F, T]
        features_video = features_audio.new_zeros(
            features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1)
        )
    elif src_audio is None:
        # Video only: stand in for the audio stream with zeros.
        features_video = self.forward_features(src_video, modality='video')
        features_audio = features_video.new_zeros(
            features_video.size(0), self.encoder_embed_dim, features_video.size(-1)
        )
    else:
        features_video = self.forward_features(src_video, modality='video')
        features_audio = self.forward_features(src_audio, modality='audio')  # features: [B, F, T]

    if self.modality_fuse == 'concat':
        features = torch.cat([features_audio, features_video], dim=1)
    elif self.modality_fuse == 'add':
        features = features_audio + features_video
    else:
        raise ValueError(f"unsupported modality_fuse: {self.modality_fuse!r}")

    features = features.transpose(1, 2)
    features = self.layer_norm(features)
    # NOTE(review): unmasked_features is never read in this method. The clone
    # and the dropout_features call below are kept only so that the module is
    # still invoked exactly as before (presumably it draws random numbers in
    # training mode); candidate for removal once that is confirmed harmless.
    unmasked_features = features.clone()

    if padding_mask is not None:
        padding_mask = self.forward_padding_mask(features, padding_mask)

    if self.post_extract_proj is not None:
        features = self.post_extract_proj(features)

    features = self.dropout_input(features)
    unmasked_features = self.dropout_features(unmasked_features)

    # features: (B, T, D), float
    # padding_mask: (B, T), bool
    x, _ = self.encoder(
        features,
        padding_mask=padding_mask,
        layer=None if output_layer is None else output_layer - 1,
    )
    return x, padding_mask