def convert()

in convert/convert_from_mxnet.py


import hashlib
import os

import gluoncv
import torch

# NOTE: create_model is assumed to be timm's model factory (this script lives in the
# timm repository); substitute your own factory if converting into a different codebase.
from timm import create_model


def convert(mxnet_name, torch_name):
    # download and load the pre-trained model
    net = gluoncv.model_zoo.get_model(mxnet_name, pretrained=True)

    # create corresponding torch model
    torch_net = create_model(torch_name)

    mxp = [(k, v) for k, v in net.collect_params().items() if 'running' not in k]
    torchp = list(torch_net.named_parameters())
    torch_params = {}

    # convert parameters
    # NOTE: we are relying on the fact that the ordering of parameters is usually
    # exactly the same between these models, so no key-name mapping is necessary.
    # The asserts below will trip if that assumption does not hold.
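    # (Illustrative only, names not taken from the source: a gluon key such as
    #  'resnetv1b0_batchnorm0_gamma' splits on '_', while its torch counterpart,
    #  e.g. 'bn1.weight', splits on '.'.)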
    for (tn, tv), (mn, mv) in zip(torchp, mxp):
        m_split = mn.split('_')
        t_split = tn.split('.')
        print(t_split, m_split)
        print(tv.shape, mv.shape)

        # ensure ordering of BN params matches, since gamma and beta have identical
        # shapes and cannot be distinguished by the shape check alone
        if m_split[-1] == 'gamma':
            assert t_split[-1] == 'weight'
        if m_split[-1] == 'beta':
            assert t_split[-1] == 'bias'

        # ensure shapes match exactly
        assert tuple(tv.shape) == tuple(mv.shape)

        # copy the MXNet NDArray to host memory as numpy, then wrap it as a torch tensor
        torch_tensor = torch.from_numpy(mv.data().asnumpy())
        torch_params[tn] = torch_tensor

    # convert buffers (batch norm running stats)
    mxb = [(k, v) for k, v in net.collect_params().items() if any(x in k for x in ['running_mean', 'running_var'])]
    torchb = [(k, v) for k, v in torch_net.named_buffers() if 'num_batches' not in k]
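    # (Illustrative: a torch buffer key looks like 'bn1.running_mean' and the matching
    #  gluon key like 'resnetv1b0_batchnorm0_running_mean'; torch's 'num_batches_tracked'
    #  buffers have no MXNet counterpart, which is why they are filtered out above.)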
    for (tn, tv), (mn, mv) in zip(torchb, mxb):
        print(tn, mn)
        print(tv.shape, mv.shape)

        # ensure ordering of BN buffers matches, since running_mean and running_var
        # have identical shapes and cannot be distinguished by shape alone
        if 'running_var' in tn:
            assert 'running_var' in mn
        if 'running_mean' in tn:
            assert 'running_mean' in mn

        torch_tensor = torch.from_numpy(mv.data().asnumpy())
        torch_params[tn] = torch_tensor

    torch_net.load_state_dict(torch_params)
    torch_filename = './%s.pth' % torch_name
    torch.save(torch_net.state_dict(), torch_filename)
    # hash the saved file and embed the first 8 hex chars of its SHA-256 in the
    # filename, the convention torch.hub uses to verify downloaded checkpoints
    with open(torch_filename, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    final_filename = os.path.splitext(torch_filename)[0] + '-' + sha_hash[:8] + '.pth'
    os.rename(torch_filename, final_filename)
    print("=> Saved converted model to '{}', SHA256: {}".format(final_filename, sha_hash))
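
A minimal usage sketch (the model names below are assumptions for illustration; the
actual script presumably drives convert() from an argparse-based main() that pairs
GluonCV model-zoo names with their timm counterparts):

if __name__ == '__main__':
    # Hypothetical pairing: GluonCV's 'resnet50_v1b' -> timm's 'gluon_resnet50_v1b'
    convert('resnet50_v1b', 'gluon_resnet50_v1b')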