in avhubert/hubert_asr.py [0:0]
def __init__(self, cfg: AVHubertAsrConfig, tgt_dict=None):
    """Wrap a pre-trained model as the encoder used for fine-tuning.

    The pre-training configuration is taken from ``cfg.w2v_args`` when it
    is already set; otherwise it is read from the checkpoint at
    ``cfg.w2v_path`` and stored back on ``cfg.w2v_args``. The pre-training
    task is then set up from that configuration, the model is built, the
    checkpoint weights are loaded when available and allowed, and the
    model's ``remove_pretraining_modules()`` is called. A final dropout
    layer and an optional output projection are attached on top.

    Args:
        cfg: Fine-tuning configuration. ``cfg.w2v_args`` is overwritten
            when it is ``None`` or an ``argparse.Namespace``.
        tgt_dict: Optional target dictionary. Only its length is used
            here, as the output size of the projection layer.

    Raises:
        AssertionError: If ``cfg.normalize`` differs from the
            ``normalize`` value of the pre-training task configuration.
    """
    self.apply_mask = cfg.apply_mask

    # Settings handed to the checkpoint loader as overrides. Every field
    # listed here has the same name on ``cfg`` and in the override dict.
    same_name_fields = (
        "dropout",
        "activation_dropout",
        "dropout_input",
        "attention_dropout",
        "mask_length",
        "mask_prob",
        "mask_selection",
        "mask_other",
        "no_mask_overlap",
        "mask_channel_length",
        "mask_channel_prob",
        "mask_channel_selection",
        "mask_channel_other",
        "no_mask_channel_overlap",
    )
    overrides = {field: getattr(cfg, field) for field in same_name_fields}
    # This override key is named differently from its ``cfg`` attribute.
    overrides["encoder_layerdrop"] = cfg.layerdrop
    overrides["feature_grad_mult"] = cfg.feature_grad_mult

    if cfg.w2v_args is not None:
        # Configuration supplied directly: no checkpoint is read here, so
        # there are no weights to load further down.
        ckpt_state = None
        pretrain_cfg = cfg.w2v_args
        if isinstance(pretrain_cfg, Namespace):
            pretrain_cfg = convert_namespace_to_omegaconf(pretrain_cfg)
            cfg.w2v_args = pretrain_cfg
    else:
        ckpt_state = checkpoint_utils.load_checkpoint_to_cpu(
            cfg.w2v_path, overrides
        )
        pretrain_cfg = ckpt_state.get("cfg")
        if pretrain_cfg is None:
            # No "cfg" entry in the checkpoint: fall back to its "args".
            pretrain_cfg = convert_namespace_to_omegaconf(ckpt_state["args"])
        cfg.w2v_args = pretrain_cfg

    assert cfg.normalize == pretrain_cfg.task.normalize, (
        "Fine-tuning works best when data normalization is the same. "
        "Please check that --normalize is set or unset for "
        "both pre-training and here"
    )

    # The pre-training task is set up against the fine-tuning data.
    pretrain_cfg.task.data = cfg.data
    pretrain_task = tasks.setup_task(pretrain_cfg.task)
    pretrained = pretrain_task.build_model(pretrain_cfg.model)

    load_weights = ckpt_state is not None and not cfg.no_pretrained_weights
    if load_weights:
        # strict=False because some modules are omitted
        pretrained.load_state_dict(ckpt_state["model"], strict=False)

    pretrained.remove_pretraining_modules()

    super().__init__(pretrain_task.source_dictionary)

    embed_dim = pretrained.encoder.embedding_dim

    self.w2v_model = pretrained
    self.final_dropout = nn.Dropout(cfg.final_dropout)
    self.freeze_finetune_updates = cfg.freeze_finetune_updates
    self.num_updates = 0

    # Output projection: sized by the target dictionary when one is given,
    # else by cfg.decoder_embed_dim when that is set and differs from the
    # encoder width, else absent.
    projection = None
    if tgt_dict is not None:
        projection = Linear(embed_dim, len(tgt_dict))
    elif getattr(cfg, "decoder_embed_dim", embed_dim) != embed_dim:
        projection = Linear(embed_dim, cfg.decoder_embed_dim)
    self.proj = projection