def load_state_dict()

in lib/net_util.py [0:0]


def load_state_dict(state_dict, net):
    """Load every shape-compatible entry of ``state_dict`` into ``net``.

    An incoming entry is used only when its key also appears in
    ``net.state_dict()`` and its ``size()`` equals that of the network's own
    tensor. Keys unknown to the network are ignored. Entries whose shape
    differs keep the network's current value.

    The top-level names (the part of each key before the first ``'.'``) of
    all network entries that were NOT overwritten are printed to stdout in
    sorted order under the label ``'not initialized'``.

    Args:
        state_dict: mapping of name -> tensor to load from, e.g. a checkpoint.
            Presumably a ``torch`` state dict; only ``.items()`` and the
            values' ``.size()`` are relied on here.
        net: object exposing ``state_dict()`` and ``load_state_dict()``
            (presumably a ``torch.nn.Module``).

    Returns:
        ``net`` itself, after loading.
    """
    own_state = net.state_dict()

    # Keep only entries the network knows about AND whose shape matches.
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in own_state and tensor.size() == own_state[name].size()
    }
    own_state.update(compatible)

    # Anything not overwritten above retains the network's existing value.
    skipped = sorted({name.split('.')[0] for name in own_state if name not in compatible})
    print('not initialized', skipped)

    # own_state still has exactly the network's keys, so a strict load succeeds.
    net.load_state_dict(own_state)

    return net