tools/convert_voxpopuli_models.py (71 lines of code) (raw):

#!/usr/bin/env python3
r"""Convert the fairseq models available in the voxpopuli repo
https://github.com/facebookresearch/voxpopuli

The published checkpoints are supposed to load with fairseq, but they fail with
the following error on almost every fairseq version:
https://github.com/facebookresearch/voxpopuli/issues/29

This script therefore parses the checkpoint file manually and reconstructs the
model with torchaudio.

Note: fairseq checkpoints are pickle files and are fully unpickled by this
script, so only convert checkpoints obtained from a trusted source.

Examples

```
python convert_voxpopuli_models.py \
  --input-file wav2vec2_base_10k_ft_fr.pt \
  --output-file wav2vec2_voxpopuli_base_10k_asr_fr.pt
```
"""


def _parse_args():
    """Parse the command-line arguments of the conversion script."""
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--input-file", required=True, help="Input checkpoint file.")
    # Required because ``_main`` always saves the converted model; ``torch.save``
    # cannot write to ``None``, so an optional flag only failed after conversion.
    parser.add_argument("--output-file", required=True, help="Output model file.")
    return parser.parse_args()


def _removeprefix(s, prefix):
    """Return ``s`` without ``prefix`` if it starts with it, else ``s`` unchanged.

    Equivalent to ``str.removeprefix`` (Python 3.9+), kept for older interpreters.
    """
    if s.startswith(prefix):
        return s[len(prefix) :]
    return s


def _load(input_file):
    """Load a fairseq checkpoint and return ``(cfg, state_dict)``.

    ``cfg`` is the checkpoint's OmegaConf config converted to plain containers,
    reduced to its ``"model"`` section. Top-level sections that were dropped are
    also removed from ``cfg["model"]["w2v_args"]`` when that nested config exists.
    ``state_dict`` is the checkpoint's ``"model"`` entry with any leading
    ``"w2v_encoder."`` prefix stripped from the parameter names.
    """
    import inspect

    import torch
    from omegaconf import OmegaConf

    # Load onto CPU so checkpoints saved from GPU also convert on CPU-only hosts.
    load_kwargs = {"map_location": "cpu"}
    # The checkpoint stores an OmegaConf config object, which the ``weights_only=True``
    # default of recent PyTorch refuses to unpickle. Older releases lack the argument.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    data = torch.load(input_file, **load_kwargs)

    cfg = OmegaConf.to_container(data["cfg"])
    # fairseq may store ``w2v_args`` as None, so only prune it when it is populated.
    w2v_args = cfg["model"].get("w2v_args")
    for key in list(cfg.keys()):
        if key != "model":
            del cfg[key]
            if w2v_args:
                # ``pop`` instead of ``del``: the nested config may not contain every top-level key.
                w2v_args.pop(key, None)
    state_dict = {_removeprefix(k, "w2v_encoder."): v for k, v in data["model"].items()}
    return cfg, state_dict


def _eval_conv_layer_config(expr):
    """Safely evaluate fairseq's ``conv_feature_layers`` expression.

    The config stores the layer list as Python source, e.g.
    ``"[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2"``. ``ast.literal_eval``
    rejects the ``+``/``*`` list operators, and ``eval`` would run arbitrary code.
    This walks the AST and accepts only numeric/None constants, lists, tuples and
    binary ``+``/``*``.

    Raises:
        ValueError: if ``expr`` contains any other syntax.
    """
    import ast

    def _evaluate(node):
        if isinstance(node, ast.Constant) and (node.value is None or isinstance(node.value, (int, float))):
            return node.value
        if isinstance(node, ast.List):
            return [_evaluate(elt) for elt in node.elts]
        if isinstance(node, ast.Tuple):
            return tuple(_evaluate(elt) for elt in node.elts)
        if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Mult)):
            left, right = _evaluate(node.left), _evaluate(node.right)
            return left + right if isinstance(node.op, ast.Add) else left * right
        raise ValueError(f"Unsupported expression in conv_feature_layers: {expr!r}")

    return _evaluate(ast.parse(expr, mode="eval").body)


def _parse_model_param(cfg, state_dict):
    """Build the keyword arguments for ``torchaudio.models.wav2vec2_model``.

    Each fairseq config key is looked up first in ``cfg["model"]`` (the fine-tuning
    config), then in ``cfg["model"]["w2v_args"]["model"]`` (the pre-training config,
    if present). The first match wins. Both ``layerdrop`` and ``encoder_layerdrop``
    map to ``encoder_layer_drop``; the latter is processed last and wins when found.

    Args:
        cfg: Config dict as returned by ``_load``.
        state_dict: Parameter dict; only ``"proj.bias"`` is read, via ``.numel()``.

    Returns:
        dict: Model parameters, with ``aux_num_out`` set to the size of
        ``proj.bias`` when present, otherwise ``None``.

    Raises:
        ValueError: if a required model parameter is found in neither config.
    """
    key_mapping = {
        "extractor_mode": "extractor_mode",
        "conv_feature_layers": "extractor_conv_layer_config",
        "conv_bias": "extractor_conv_bias",
        "encoder_embed_dim": "encoder_embed_dim",
        "dropout_input": "encoder_projection_dropout",
        "conv_pos": "encoder_pos_conv_kernel",
        "conv_pos_groups": "encoder_pos_conv_groups",
        "encoder_layers": "encoder_num_layers",
        "encoder_attention_heads": "encoder_num_heads",
        "attention_dropout": "encoder_attention_dropout",
        "encoder_ffn_embed_dim": "encoder_ff_interm_features",
        "activation_dropout": "encoder_ff_interm_dropout",
        "dropout": "encoder_dropout",
        "layer_norm_first": "encoder_layer_norm_first",
        "layerdrop": "encoder_layer_drop",
        "encoder_layerdrop": "encoder_layer_drop",
    }

    src_dicts = [cfg["model"]]
    # ``w2v_args`` may be present but None, so test its value rather than the key.
    w2v_args = cfg["model"].get("w2v_args")
    if w2v_args:
        src_dicts.append(w2v_args["model"])

    params = {}
    for src, tgt in key_mapping.items():
        for model_cfg in src_dicts:
            if src in model_cfg:
                params[tgt] = model_cfg[src]
                break

    # Explicit check (not ``assert``) so it survives ``python -O`` and names the gap.
    missing = sorted(set(key_mapping.values()) - params.keys())
    if missing:
        raise ValueError(f"Model configuration is missing required parameters: {missing}")

    if params["extractor_mode"] == "default":
        params["extractor_mode"] = "group_norm"
    if isinstance(params["extractor_conv_layer_config"], str):
        params["extractor_conv_layer_config"] = _eval_conv_layer_config(params["extractor_conv_layer_config"])
    # A fine-tuned (ASR) checkpoint has an output projection; its size is the vocabulary size.
    params["aux_num_out"] = state_dict["proj.bias"].numel() if "proj.bias" in state_dict else None
    return params


def _main(args):
    """Convert ``args.input_file`` and save the torchaudio state dict to ``args.output_file``."""
    import json

    import torch
    import torchaudio
    from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert

    cfg, state_dict = _load(args.input_file)
    params = _parse_model_param(cfg, state_dict)
    print(json.dumps(params, indent=4))
    model = torchaudio.models.wav2vec2_model(**params)
    model.load_state_dict(_convert(state_dict))
    torch.save(model.state_dict(), args.output_file)


if __name__ == "__main__":
    _main(_parse_args())