scripts/utility/convert_pytorch_resnet.py (58 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.
"""Convert a pre-trained PyTorch ResNet checkpoint to this project's key layout.

The script downloads the checkpoint selected by ``NET`` from ``MODEL_URLS``,
renames its entries (``conv1``/``bn1`` -> ``mod1.*``, ``layerN.B.*`` ->
``mod{N+1}.block{B+1}.*``) and saves the resulting dictionary to ``OUT``.

Usage: ``python convert_pytorch_resnet.py NET OUT``
"""

import argparse

import torch
import torch.utils.model_zoo as model_zoo

# Download locations of the pre-trained checkpoints, keyed by network name.
MODEL_URLS = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

# Keyword arguments passed to `convert` for each network: the number of
# blocks in each `layerN` group and whether blocks have 3 convs (bottleneck)
# or 2.
NETS = {
    "resnet18": {"structure": [2, 2, 2, 2], "bottleneck": False},
    "resnet34": {"structure": [3, 4, 6, 3], "bottleneck": False},
    "resnet50": {"structure": [3, 4, 6, 3], "bottleneck": True},
    "resnet101": {"structure": [3, 4, 23, 3], "bottleneck": True},
    "resnet152": {"structure": [3, 8, 36, 3], "bottleneck": True},
}

# Parameter-name suffixes copied for convolution and batch-norm layers.
CONV_PARAMS = ["weight"]
BN_PARAMS = ["weight", "bias", "running_mean", "running_var"]

parser = argparse.ArgumentParser(description="Convert pre-trained ResNet from Pytorch to our format")
parser.add_argument("net", metavar="NET", type=str, choices=list(MODEL_URLS.keys()), help="Network name")
parser.add_argument("out_file", metavar="OUT", type=str, help="Output file")


def copy_layer(inm, outm, name_in, name_out, params):
    """Copy the entries of one layer from `inm` to `outm` under a new prefix.

    For every ``p`` in `params`, ``outm[name_out + "." + p]`` is set to
    ``inm[name_in + "." + p]``. `outm` is modified in place.

    Raises:
        KeyError: if any ``name_in + "." + p`` is missing from `inm`. Entries
            copied before the missing one remain in `outm`.
    """
    for param_name in params:
        outm[name_out + "." + param_name] = inm[name_in + "." + param_name]


def _has_layer(state, name, params):
    """Return True if `state` holds every ``name + "." + p`` for p in `params`."""
    return all(name + "." + param_name in state for param_name in params)


def convert(model, structure, bottleneck):
    """Build a dictionary with this project's key names from a ResNet checkpoint.

    Args:
        model: mapping with the source checkpoint entries (keys such as
            ``conv1.weight`` or ``layer1.0.bn2.running_mean``).
        structure: number of blocks in each ``layerN`` group, in order.
        bottleneck: if true each block has 3 conv/bn pairs, otherwise 2.

    Returns:
        A new dict mapping the renamed keys to the same values as in `model`.

    Raises:
        KeyError: if a required (non-projection) entry is missing from `model`.
    """
    out = {}
    num_convs = 3 if bottleneck else 2

    # Initial module
    copy_layer(model, out, "conv1", "mod1.conv1", CONV_PARAMS)
    copy_layer(model, out, "bn1", "mod1.bn1", BN_PARAMS)

    # Other modules: source groups / convs are 1-based and blocks 0-based,
    # while the output uses `mod2` for the first group and 1-based blocks.
    for mod_id, num in enumerate(structure):
        for block_id in range(num):
            src_block = "layer{}.{}".format(mod_id + 1, block_id)
            dst_block = "mod{}.block{}".format(mod_id + 2, block_id + 1)

            for conv_id in range(num_convs):
                copy_layer(model, out,
                           "{}.conv{}".format(src_block, conv_id + 1),
                           "{}.convs.conv{}".format(dst_block, conv_id + 1),
                           CONV_PARAMS)
                copy_layer(model, out,
                           "{}.bn{}".format(src_block, conv_id + 1),
                           "{}.convs.bn{}".format(dst_block, conv_id + 1),
                           BN_PARAMS)

            # The projection (downsample) branch only exists on some blocks.
            # Copy it only when it is complete, so the output never contains
            # a projection conv without its batch norm (or a partial one).
            proj_conv = src_block + ".downsample.0"
            proj_bn = src_block + ".downsample.1"
            if _has_layer(model, proj_conv, CONV_PARAMS) and _has_layer(model, proj_bn, BN_PARAMS):
                copy_layer(model, out, proj_conv, dst_block + ".proj_conv", CONV_PARAMS)
                copy_layer(model, out, proj_bn, dst_block + ".proj_bn", BN_PARAMS)

    return out


def main():
    """Parse the command line, download and convert the network, save it."""
    args = parser.parse_args()

    original = model_zoo.load_url(MODEL_URLS[args.net])
    converted = convert(original, **NETS[args.net])
    torch.save(converted, args.out_file)


if __name__ == '__main__':
    main()