def initialize_master_device_model_params()

in lib/utils/checkpoints_rel.py [0:0]

import logging
import pickle
from collections import OrderedDict

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

# cfg and helpers_rel are project-local modules; their import paths depend on
# the repository layout and are left as-is here.

logger = logging.getLogger(__name__)


def initialize_master_device_model_params(model, weights_file):
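    """Initialize model parameters on the master (root) device from a pickled
    weights file. Returns the checkpointed iteration to resume training from
    and the previous learning rate (0 and None if absent from the file).
    """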
    ws_blobs = workspace.Blobs()
    logger.info("Initializing model params from file: {}".format(weights_file))
    # Open in binary mode: the weights file is a pickled dict of blobs
    with open(weights_file, 'rb') as fopen:
        blobs = pickle.load(fopen)
    if 'blobs' in blobs:
        blobs = blobs['blobs']
    # Ordered set (as OrderedDict keys) of unscoped blob names to restore
    unscoped_blob_names = OrderedDict()

    # Recover the iteration from which training should resume (returned to the caller)
    model_iter = 0
    if 'model_iter' in blobs:
        model_iter = blobs['model_iter']
    prev_lr = None
    if 'lr' in blobs:
        prev_lr = round(blobs['lr'], 6)

    # Collect the names of params, their momentum blobs, and computed params to initialize
    if 'test' not in model.net.Name():
        for param in model.params:
            # Layers that are frozen during fine-tuning have no momentum blobs
            scoped_blob_name = str(param) + '_momentum'
            if workspace.HasBlob(scoped_blob_name):
                unscoped_blob_names[helpers_rel.unscope_name(scoped_blob_name)] = True
    # NOTE: GetAllParams() and GetAllParams('') currently behave the same; use the cleaner no-argument form
    for blob in model.GetAllParams():
        unscoped_blob_names[helpers_rel.unscope_name(str(blob))] = True

    # Feed the loaded weights only on the root (master) device; parameters on
    # the remaining devices are expected to be synchronized from it elsewhere.
    root_device_id = cfg.ROOT_DEVICE_ID
    device = caffe2_pb2.CUDA if cfg.DEVICE == 'GPU' else caffe2_pb2.CPU
    with core.NameScope('gpu_{}'.format(root_device_id)):
        with core.DeviceScope(core.DeviceOption(device, root_device_id)):
            for unscoped_blob_name in unscoped_blob_names.keys():
                scoped_blob_name = helpers_rel.scoped_name(unscoped_blob_name)
                if unscoped_blob_name not in blobs:
                    logger.info('{:s} not found'.format(unscoped_blob_name))
                    continue
                if model.train:
                    logger.info('{:s} loaded from weights file into: {:s}'.format(
                        unscoped_blob_name, scoped_blob_name))
                if scoped_blob_name in ws_blobs:
                    ws_blob = workspace.FetchBlob(scoped_blob_name)
                    assert ws_blob.shape == blobs[unscoped_blob_name].shape, \
                        ('Workspace blob {} with shape {} does not match '
                         'weights file shape {}').format(
                            unscoped_blob_name, ws_blob.shape,
                            blobs[unscoped_blob_name].shape)
                data = blobs[unscoped_blob_name].astype(np.float32, copy=False)
                workspace.FeedBlob(scoped_blob_name, data)
    return model_iter, prev_lr
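

For reference, the sketch below shows a pickled checkpoint layout this loader can consume, inferred only from the keys read above: an optional top-level 'blobs' dict mapping unscoped blob names (including '<param>_momentum' entries) to arrays, plus optional 'model_iter' and 'lr' entries. The parameter name conv1_w, its shape, and the output path are hypothetical.

import pickle

import numpy as np

# Hypothetical checkpoint matching the format read by the loader above.
checkpoint = {
    'blobs': {
        'conv1_w': np.zeros((64, 3, 7, 7), dtype=np.float32),
        'conv1_w_momentum': np.zeros((64, 3, 7, 7), dtype=np.float32),
    },
    'model_iter': 10000,  # iteration to resume from (returned as model_iter)
    'lr': 0.01,           # learning rate at save time (returned as prev_lr)
}
with open('/tmp/example_weights.pkl', 'wb') as fwrite:
    pickle.dump(checkpoint, fwrite, pickle.HIGHEST_PROTOCOL)

# With a model whose blobs already live in the workspace:
# start_iter, prev_lr = initialize_master_device_model_params(
#     model, '/tmp/example_weights.pkl')

Note that the loader casts every array to float32 and rounds 'lr' to six decimal places, so the stored values need not match those dtypes exactly.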