in avhubert/hubert_asr.py [0:0]
def build_model(cls, cfg, task):
    """Build a new model instance."""

    arg_overrides = {
        "dropout": cfg.dropout,
        "activation_dropout": cfg.activation_dropout,
        "dropout_input": cfg.dropout_input,
        "attention_dropout": cfg.attention_dropout,
        "mask_length": cfg.mask_length,
        "mask_prob": cfg.mask_prob,
        "mask_selection": cfg.mask_selection,
        "mask_other": cfg.mask_other,
        "no_mask_overlap": cfg.no_mask_overlap,
        "mask_channel_length": cfg.mask_channel_length,
        "mask_channel_prob": cfg.mask_channel_prob,
        "mask_channel_selection": cfg.mask_channel_selection,
        "mask_channel_other": cfg.mask_channel_other,
        "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
        "encoder_layerdrop": cfg.layerdrop,
        "feature_grad_mult": cfg.feature_grad_mult,
    }
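
    # Resolve the pre-training config: either load it from the checkpoint at
    # cfg.w2v_path (applying the dropout/masking overrides above), or reuse
    # the config already stored in cfg.w2v_args.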
    if cfg.w2v_args is None:
        state = checkpoint_utils.load_checkpoint_to_cpu(
            cfg.w2v_path, arg_overrides
        )
        w2v_args = state.get("cfg", None)
        if w2v_args is None:
            w2v_args = convert_namespace_to_omegaconf(state["args"])
        cfg.w2v_args = w2v_args
    else:
        state = None
        w2v_args = cfg.w2v_args
        if isinstance(w2v_args, Namespace):
            cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
                w2v_args
            )
    assert cfg.normalize == w2v_args.task.normalize, (
        "Fine-tuning works best when data normalization is the same. "
        "Please check that --normalize is set or unset for "
        "both pre-training and here"
    )
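
    # Point the pre-training task config at the fine-tuning data directory,
    # then re-instantiate that task (restoring any saved task state) so the
    # encoder can be rebuilt with its original configuration.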
    w2v_args.task.data = cfg.data

    task_pretrain = tasks.setup_task(w2v_args.task)
    if state is not None:
        task_pretrain.load_state_dict(state['task_state'])
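
    # Build the pre-trained AV-HuBERT model from its original config, wrap it
    # as a fairseq encoder, and load the checkpoint weights unless disabled.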
    encoder_ = task_pretrain.build_model(w2v_args.model)
    encoder = HubertEncoderWrapper(encoder_)
    if state is not None and not cfg.no_pretrained_weights:
        # set strict=False because we omit some modules
        del state['model']['mask_emb']
        encoder.w2v_model.load_state_dict(state["model"], strict=False)

    encoder.w2v_model.remove_pretraining_modules()
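
    # Build the seq2seq decoder: target-side embeddings sized to the
    # fine-tuning target dictionary, fed into a Transformer decoder.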
    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    def build_embedding(dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
        return emb

    decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
    decoder = TransformerDecoder(cfg, tgt_dict, decoder_embed_tokens)

    return AVHubertSeq2Seq(encoder, decoder, tgt_dict, cfg)
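
For context, a minimal sketch of how fairseq typically reaches this factory during fine-tuning; the load_finetune_model helper and the exact config layout are assumptions for illustration, not part of hubert_asr.py:

from fairseq import tasks

def load_finetune_model(finetune_cfg):
    # finetune_cfg: a full fairseq fine-tuning config (e.g. an OmegaConf DictConfig)
    task = tasks.setup_task(finetune_cfg.task)      # sets up the AV-HuBERT fine-tuning task
    model = task.build_model(finetune_cfg.model)    # dispatches to build_model(cls, cfg, task) above
    return task, model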