in tools/convert_voxpopuli_models.py [0:0]
def _eval_conv_layer_config(expr):
    """Evaluate a ``conv_feature_layers`` expression string without ``eval``.

    The config stores the feature-extractor layout as a Python expression such
    as ``"[(512, 10, 5)] + [(512, 3, 2)] * 4"``. ``ast.literal_eval`` rejects
    the ``+`` / ``*`` list arithmetic used there, so this walks the parsed AST
    and accepts only int/float literals, lists, tuples and binary ``+`` / ``*``.

    Args:
        expr (str): The expression string.

    Returns:
        The evaluated value (a list of tuples for a well-formed layout).

    Raises:
        ValueError: If ``expr`` is not valid syntax, contains any other
            construct (names, calls, attributes, ...), or combines operands of
            incompatible types.
    """
    import ast

    def _visit(node):
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.List):
            return [_visit(elt) for elt in node.elts]
        if isinstance(node, ast.Tuple):
            return tuple(_visit(elt) for elt in node.elts)
        if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Mult)):
            left = _visit(node.left)
            right = _visit(node.right)
            return left + right if isinstance(node.op, ast.Add) else left * right
        raise ValueError(f"Unsupported construct in conv layer config: {ast.dump(node)}")

    try:
        # builtin eval() strips leading whitespace; ast.parse does not.
        tree = ast.parse(expr.strip(), mode="eval")
    except SyntaxError as err:
        raise ValueError(f"Invalid conv layer config: {expr!r}") from err
    try:
        return _visit(tree.body)
    except TypeError as err:
        # e.g. "[1] + 2": structurally allowed but not a valid combination.
        raise ValueError(f"Invalid conv layer config: {expr!r}") from err


def _parse_model_param(cfg, state_dict):
    """Translate a checkpoint's model config into model-builder keyword arguments.

    Values are read from ``cfg["model"]``. If that mapping has a non-``None``
    ``"w2v_args"`` entry, ``cfg["model"]["w2v_args"]["model"]`` is consulted as
    a fallback. A key found at the top level takes precedence over the nested
    one. (These key names look like fairseq's config layout — confirm against
    the checkpoints this script targets.)

    Post-processing:
      * ``extractor_mode == "default"`` is renamed to ``"group_norm"``.
      * ``extractor_conv_layer_config``, if given as an expression string, is
        evaluated (safely) into a Python list.
      * ``aux_num_out`` is ``state_dict["proj.bias"].numel()`` when that key
        exists, else ``None``.

    Args:
        cfg: Mapping with a ``"model"`` entry holding the model config.
        state_dict: Mapping of parameter names to tensors (anything with
            ``.numel()``).

    Returns:
        dict: The 15 mapped parameters plus ``"aux_num_out"``.

    Raises:
        ValueError: If any mapped parameter is missing from the config, or the
            conv layer config string cannot be safely evaluated.
    """
    key_mapping = {
        "extractor_mode": "extractor_mode",
        "conv_feature_layers": "extractor_conv_layer_config",
        "conv_bias": "extractor_conv_bias",
        "encoder_embed_dim": "encoder_embed_dim",
        "dropout_input": "encoder_projection_dropout",
        "conv_pos": "encoder_pos_conv_kernel",
        "conv_pos_groups": "encoder_pos_conv_groups",
        "encoder_layers": "encoder_num_layers",
        "encoder_attention_heads": "encoder_num_heads",
        "attention_dropout": "encoder_attention_dropout",
        "encoder_ffn_embed_dim": "encoder_ff_interm_features",
        "activation_dropout": "encoder_ff_interm_dropout",
        "dropout": "encoder_dropout",
        "layer_norm_first": "encoder_layer_norm_first",
        # Both aliases feed the same target; when both are found, the later
        # entry ("encoder_layerdrop") overwrites the earlier one.
        "layerdrop": "encoder_layer_drop",
        "encoder_layerdrop": "encoder_layer_drop",
    }
    model_cfg = cfg["model"]
    src_dicts = [model_cfg]
    # The nested config may be absent or explicitly None; only use it when set.
    w2v_args = model_cfg["w2v_args"] if "w2v_args" in model_cfg else None
    if w2v_args is not None:
        src_dicts.append(w2v_args["model"])

    params = {}
    for src, tgt in key_mapping.items():
        for src_dict in src_dicts:
            if src in src_dict:
                params[tgt] = src_dict[src]
                break

    # Validate before touching any individual key so the error names every gap.
    missing = sorted(tgt for tgt in set(key_mapping.values()) if tgt not in params)
    if missing:
        raise ValueError(f"Model config is missing values for: {', '.join(missing)}")

    if params["extractor_mode"] == "default":
        params["extractor_mode"] = "group_norm"
    conv_config = params["extractor_conv_layer_config"]
    if isinstance(conv_config, str):
        params["extractor_conv_layer_config"] = _eval_conv_layer_config(conv_config)
    params["aux_num_out"] = state_dict["proj.bias"].numel() if "proj.bias" in state_dict else None
    return params