def init_model()

in func/train.py [0:0]


import itertools

import torch


def init_model(model, ckpt_path, modules_to_keep, logger):
    """Initialize model with weights from ckpt_path.
    Args:
        ckpt_path (str): A string with path to file
        modules_to_keep (str): A comma sep string with the module name prefix
            that should be loaded from the checkpoint
    """
    logger.debug('Initializing %s from ckpt %s, keeping modules: %s',
                 model, ckpt_path, modules_to_keep)
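    # Load the checkpoint onto the CPU so initialization works regardless of
    # which device the model will eventually be trained on.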
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'classy_state_dict' in checkpoint:
        state_dict = checkpoint['classy_state_dict']
        # This layout likely comes from a VISSL codebase, so the actual trunk
        # params are nested as follows. TODO: support this more generally.
        state_dict = state_dict['base_model']['model']['trunk']
    else:
        state_dict = checkpoint
    if modules_to_keep:
        # Keep only the entries of state_dict whose names start with one of
        # the requested module prefixes, and strip that prefix from the name.
        filtered_state_dict = {}
        for key, val in state_dict.items():
            for mod_name in modules_to_keep.split(','):
                if key.startswith(mod_name):
                    filtered_state_dict[key[len(mod_name):]] = val
                    break
        state_dict = filtered_state_dict
    # Ignore any parameters/buffers (bn mean/var) where shape does not match
    for name, param in itertools.chain(model.named_parameters(),
                                       model.named_buffers()):
        if name in state_dict and state_dict[name].shape != param.shape:
            logger.warning('Ckpt shape mismatch for %s (%s vs %s). Ignoring.',
                           name, state_dict[name].shape, param.shape)
            del state_dict[name]
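    # strict=False: entries dropped above (or missing from the checkpoint)
    # are reported by load_state_dict instead of raising an error.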
    missing_keys, unexp_keys = model.load_state_dict(state_dict, strict=False)
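    # missing_keys are model parameters absent from the checkpoint;
    # unexp_keys are checkpoint entries the model has no parameter for.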
    logger.warning('Could not init from %s: %s', ckpt_path, missing_keys)
    logger.warning('Unused keys in %s: %s', ckpt_path, unexp_keys)
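

A minimal usage sketch (the model, checkpoint path, and 'module.' prefix below
are illustrative placeholders, not part of func/train.py):

    import logging
    import torch.nn as nn

    logger = logging.getLogger(__name__)
    model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
    # Hypothetical checkpoint path and prefix: only ckpt keys starting with
    # 'module.' are kept, with the prefix stripped before matching model names.
    init_model(model, '/path/to/checkpoint.pth', 'module.', logger)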