# initialize_master_gpu_model_params() — from lib/utils/checkpoints.py


def initialize_master_gpu_model_params(
        model, weights_file, load_momentum=True):
    """Initialize model params on the root GPU from a pickled weights file.

    Loads a checkpoint pickle, feeds every matching (optionally momentum)
    blob into the root-GPU namescope of the Caffe2 workspace, and broadcasts
    the saved learning rate to every GPU.

    Args:
        model: Caffe2 model helper; provides `params`, `TrainableParams()`,
            `GetAllParams()` and `net.Name()`.
        weights_file: path to a pickled checkpoint. The pickle is either the
            blobs dict itself or a dict with a 'blobs' key wrapping it, and
            may carry 'model_iter' and 'lr' entries.
        load_momentum: if True (and this is not a test net), also restore
            the '<param>_momentum' blobs for trainable params.

    Returns:
        (model_iter, prev_lr): iteration to resume training from (0 if the
        checkpoint has no 'model_iter'), and the learning rate fed to all
        GPUs.

    Raises:
        Exception: if the checkpoint has no 'lr' blob and
            cfg.TRAIN.RESET_START_ITER is False.
    """
    ws_blobs = workspace.Blobs()
    logger.info("Initializing model params from file: {}".format(weights_file))
    # Pickle files must be opened in binary mode: text mode breaks pickle
    # reads on Python 3 and on Windows. NOTE(security): pickle.load on an
    # untrusted checkpoint can execute arbitrary code — only load trusted
    # files.
    with open(weights_file, 'rb') as fopen:
        blobs = pickle.load(fopen)
    # Some checkpoints wrap the blob dict under a 'blobs' key; unwrap it.
    if 'blobs' in blobs:
        blobs = blobs['blobs']
    unscoped_blob_names = OrderedDict()

    # Return the model iter from which training should start
    model_iter = 0
    if 'model_iter' in blobs:
        model_iter = blobs['model_iter']

    if 'lr' in blobs:
        prev_lr = float(blobs['lr'])
    elif cfg.TRAIN.RESET_START_ITER:
        # No saved lr, but we are restarting from iter 0 anyway.
        prev_lr = 1.
    else:
        raise Exception('No lr blob found.')

    # initialize params, params momentum, computed params
    # Momentum blobs first, then the params themselves; OrderedDict keeps
    # the feed order deterministic.
    if 'test' not in model.net.Name() and load_momentum:
        for param in model.params:
            if param in model.TrainableParams():
                unscoped_blob_names[misc.unscope_name(
                    str(param) + '_momentum')] = True
    for blob in model.GetAllParams():
        unscoped_blob_names[misc.unscope_name(str(blob))] = True

    root_gpu_id = cfg.ROOT_GPU_ID
    with core.NameScope('gpu_{}'.format(root_gpu_id)):
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, root_gpu_id)):
            for unscoped_blob_name in unscoped_blob_names.keys():
                scoped_blob_name = misc.scoped_name(unscoped_blob_name)
                if unscoped_blob_name not in blobs:
                    logger.info('{:s} not found'.format(unscoped_blob_name))
                    continue
                if scoped_blob_name in ws_blobs:
                    ws_blob = workspace.FetchBlob(scoped_blob_name)

                    # Classifier ('pred') layers: skip if the element count
                    # differs (e.g. different number of classes), otherwise
                    # reshape to the workspace layout.
                    if 'pred' in unscoped_blob_name:
                        if np.prod(ws_blob.shape) \
                                != np.prod(blobs[unscoped_blob_name].shape):
                            logger.info(('{:s} (classifier) found but ' +
                                            'unmatching (not loaded):' +
                                            '{} ---> {}')
                                        .format(
                                            unscoped_blob_name,
                                            blobs[unscoped_blob_name].shape,
                                            ws_blob.shape))
                            continue
                        else:
                            blobs[unscoped_blob_name] = np.reshape(
                                blobs[unscoped_blob_name], ws_blob.shape)

                    # Rank mismatch: inflate a lower-rank (e.g. 2D conv)
                    # checkpoint blob along a new axis 2 to match the
                    # workspace blob (e.g. 3D conv); leading and trailing
                    # dims must already agree.
                    if len(ws_blob.shape) != \
                            len(blobs[unscoped_blob_name].shape):
                        # inflate if so
                        assert ws_blob.shape[:2] == \
                            blobs[unscoped_blob_name].shape[:2], \
                            ('Workspace blob {} with shape {} does not match '
                             'weights file shape {}').format(
                                unscoped_blob_name, ws_blob.shape,
                                blobs[unscoped_blob_name].shape)
                        assert ws_blob.shape[-2:] == \
                            blobs[unscoped_blob_name].shape[-2:], \
                            ('Workspace blob {} with shape {} does not match '
                             'weights file shape {}').format(
                                unscoped_blob_name, ws_blob.shape,
                                blobs[unscoped_blob_name].shape)

                        logger.info(
                            ('{:s} loaded from weights file into: {:s}' +
                                    ' inflated {} ---> {}').format(
                                unscoped_blob_name, scoped_blob_name,
                                blobs[unscoped_blob_name].shape,
                                ws_blob.shape))
                        # inflate: replicate along the new axis and divide so
                        # the inflated filter's response matches the original.
                        num_inflate = ws_blob.shape[2]
                        blobs[unscoped_blob_name] = np.stack(
                            [blobs[unscoped_blob_name]] * num_inflate,
                            axis=2) / float(num_inflate)
                    else:
                        logger.info(
                            ('{:s} loaded from weights file into: {:s}' +
                                    ' {}').format(
                                unscoped_blob_name, scoped_blob_name,
                                ws_blob.shape))

                    # After any reshape/inflation, shapes must match exactly.
                    assert ws_blob.shape == blobs[unscoped_blob_name].shape, \
                        ('Workspace blob {} with shape {} does not match '
                         'weights file shape {}').format(
                            unscoped_blob_name, ws_blob.shape,
                            blobs[unscoped_blob_name].shape)
                # Feed as float32; also feeds blobs not yet in the workspace.
                data = blobs[unscoped_blob_name].astype(np.float32, copy=False)
                workspace.FeedBlob(scoped_blob_name, data)

    # hack fix: load and broadcast lr to all gpus
    for i in range(cfg.NUM_GPUS):
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, i)):
            workspace.FeedBlob(
                'gpu_{}/lr'.format(i), np.array(prev_lr, dtype=np.float32))

    return model_iter, prev_lr