def restore()

in easy_rec/python/model/easy_rec_model.py [0:0]


  def restore(self,
              ckpt_path,
              include_global_step=False,
              ckpt_var_map_path='',
              force_restore_shape_compatible=False):
    """Restore variables from ckpt_path.

    steps:
      1. list the variables in graph that need to be restored
      2. inspect checkpoint and find the variables that could restore from checkpoint
         substitute scope names in case necessary
      3. call tf.train.init_from_checkpoint to restore the variables

    Some variables are not restored through tf.train.init_from_checkpoint:
      - EmbeddingVariable (plain or partitioned): its `_initializer_op` is
        replaced by the restore op of its first saveable object, so running
        the initializer loads it from ckpt_path.
      - sok.DynamicVariable: keys/values are loaded with load_embed_lib and
        assigned after the variable's original initializer op.

    Args:
       ckpt_path: checkpoint path to restore from
       include_global_step: whether to restore global_step variable
       ckpt_var_map_path: variable map from graph variables to variables in a checkpoint
          each line consists of: variable name in graph  variable name in ckpt
       force_restore_shape_compatible: if variable shape is incompatible, clip or pad
          variables in checkpoint, and then restore

    Returns:
      IncompatibleShapeRestoreHook if force_restore_shape_compatible else None
    """

    def _restore_embedding_variable(ev):
      # Replace the initializer of an EmbeddingVariable (or one partition of
      # it) with an op restoring it from ckpt_path.
      from tensorflow.python.training import saver
      names_to_saveables = saver.BaseSaverBuilder.OpListToDict([ev])
      saveable_objects = []
      for name, op in names_to_saveables.items():
        saveable_objects.extend(
            saver.BaseSaverBuilder.SaveableObjectsForOp(op, name))
      ev._initializer_op = saveable_objects[0].restore([ckpt_path], None)

    name2var_map = self._get_restore_vars(ckpt_var_map_path)
    logging.info('start to restore from %s', ckpt_path)

    ckpt_reader = tf.train.NewCheckpointReader(ckpt_path)
    ckpt_var2shape_map = ckpt_reader.get_variable_to_shape_map()
    if not include_global_step:
      ckpt_var2shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)

    # checkpoint variable name -> graph variable(s) for init_from_checkpoint
    vars_in_ckpt = {}
    # graph variable -> temporary variable holding the ckpt-shaped value
    incompatible_shape_var_map = {}
    fail_restore_vars = []
    for variable_name, variable in sorted(name2var_map.items()):
      if variable_name in ckpt_var2shape_map:
        logging.info('restore %s', variable_name)
        ckpt_var_shape = ckpt_var2shape_map[variable_name]
        if isinstance(variable, list):
          # A list of partitions: the parts are stacked along axis 0, so the
          # full shape sums dim 0 over all parts and keeps the other dims.
          shape_arr = [x.get_shape() for x in variable]
          var_shape = list(shape_arr[0])
          for x in shape_arr[1:]:
            var_shape[0] += x[0]
          var_shape = tensor_shape.TensorShape(var_shape)
          variable = variables.PartitionedVariable(
              variable_name,
              var_shape,
              variable[0].dtype,
              variable,
              partitions=[len(variable)] + [1] * (len(var_shape) - 1))
        else:
          var_shape = variable.shape.as_list()
        if ckpt_var_shape == var_shape:
          vars_in_ckpt[variable_name] = list(variable) if isinstance(
              variable, variables.PartitionedVariable) else variable
        elif len(ckpt_var_shape) == len(var_shape):
          if force_restore_shape_compatible:
            # create a variable compatible with checkpoint to restore; the
            # (variable -> tmp_var) pair is handed to the returned
            # IncompatibleShapeRestoreHook.
            # `variable` is either a plain variable or a PartitionedVariable
            # here (lists were converted above); both expose `.dtype`.
            with tf.variable_scope('incompatible_shape_restore'):
              tmp_var = tf.get_variable(
                  name=variable_name + '_T_E_M_P',
                  shape=ckpt_var_shape,
                  trainable=False,
                  # add to a special collection for easy reference
                  # by tf.get_collection('T_E_M_P_RESTROE')
                  collections=['T_E_M_P_RESTROE'],
                  dtype=variable.dtype)
            vars_in_ckpt[variable_name] = tmp_var
            incompatible_shape_var_map[variable] = tmp_var
            logging.info('incompatible restore %s[%s, %s]', variable_name,
                         var_shape, ckpt_var_shape)
          else:
            logging.warning(
                'Variable [%s] is available in checkpoint, but '
                'incompatible shape with model variable.', variable_name)
        else:
          logging.warning(
              'Variable [%s] is available in checkpoint, but '
              'incompatible shape dims with model variable.', variable_name)
      elif 'EmbeddingVariable' in str(type(variable)):
        # EmbeddingVariables are stored as '<name>-keys' / ... in the ckpt
        if '%s-keys' % variable_name not in ckpt_var2shape_map:
          fail_restore_vars.append(variable_name)
          continue
        logging.info('restore embedding_variable %s', variable_name)
        _restore_embedding_variable(variable)
      elif isinstance(variable, list) and 'EmbeddingVariable' in str(
          type(variable[0])):
        if '%s/part_0-keys' % variable_name not in ckpt_var2shape_map:
          fail_restore_vars.append(variable_name)
          continue
        logging.info('restore partitioned embedding_variable %s',
                     variable_name)
        for part_var in variable:
          _restore_embedding_variable(part_var)
      elif sok is not None and isinstance(variable, sok.DynamicVariable):
        logging.info('restore dynamic_variable %s', variable_name)
        keys, vals = load_embed_lib.load_kv_embed(
            task_index=hvd.rank(),
            task_num=hvd.size(),
            embed_dim=variable._dimension,
            var_name='embed-' + variable.name.replace('/', '__'),
            ckpt_path=ckpt_path)
        # assign the loaded keys/values only after the original initializer
        with ops.control_dependencies([variable._initializer_op]):
          variable._initializer_op = dynamic_variable_ops.dummy_var_assign(
              variable.handle, keys, vals)
      else:
        fail_restore_vars.append(variable_name)

    for variable_name in fail_restore_vars:
      # Momentum slot variables are not reported, presumably because they
      # are expected to be absent (e.g. ckpt trained with another optimizer).
      if 'Momentum' not in variable_name:
        logging.warning('Variable [%s] is not available in checkpoint',
                        variable_name)

    tf.train.init_from_checkpoint(ckpt_path, vars_in_ckpt)

    if force_restore_shape_compatible:
      return estimator_utils.IncompatibleShapeRestoreHook(
          incompatible_shape_var_map)
    return None