def _get_restore_vars()

in easy_rec/python/model/easy_rec_model.py [0:0]


  def _get_restore_vars(self, ckpt_var_map_path):
    """Collect graph variables to restore, keyed by their checkpoint names.

    Args:
      ckpt_var_map_path: path to a variable map file from graph variable names
          to checkpoint variable names; each non-blank line consists of
          `<variable name in graph> <variable name in ckpt>` separated by
          whitespace. If it is '', the variable filter and scope update
          returned by `self.get_restore_filter()` are applied instead.

    Returns:
      a dict mapping the variable name to restore from the checkpoint to
      either a single graph variable or, for partitioned variables, the list
      of its partitions. If `ckpt_var_map_path` is given but does not exist,
      the dict keyed by the graph variable names is returned unchanged.
    """
    # here must use global_variables, because variables such as moving_mean
    # and moving_variance are usually not trainable in detection models
    all_vars = tf.global_variables()
    PARTITION_PATTERN = '/part_[0-9]+'
    # output index suffix such as ':0'; accept multi-digit indices too
    VAR_SUFIX_PATTERN = ':[0-9]+$'

    name2var = {}
    for one_var in all_vars:
      var_name = re.sub(VAR_SUFIX_PATTERN, '', one_var.name)
      # only strip '/part_N' for real slices of a partitioned variable
      # (those carry save slice info), so they group under one name
      if re.search(PARTITION_PATTERN,
                   var_name) and one_var._save_slice_info is not None:
        var_name = re.sub(PARTITION_PATTERN, '', var_name)
        is_part = True
      else:
        is_part = False
      if var_name in name2var:
        # only partitions of the same variable may share a stripped name
        assert is_part and isinstance(name2var[var_name], list), \
            'multiple vars: %s' % var_name
        name2var[var_name].append(one_var)
      else:
        name2var[var_name] = [one_var] if is_part else one_var

    if ckpt_var_map_path != '':
      if not gfile.Exists(ckpt_var_map_path):
        logging.warning('%s not exist', ckpt_var_map_path)
        return name2var

      # load var map: graph variable name -> checkpoint variable name
      name_map = {}
      with gfile.GFile(ckpt_var_map_path, 'r') as fin:
        for one_line in fin:
          line_tok = one_line.split()
          if not line_tok:
            # blank line: nothing to map, not an error
            continue
          if len(line_tok) != 2:
            logging.warning('Failed to process: %s', one_line.strip())
            continue
          name_map[line_tok[0]] = line_tok[1]

      # re-key mapped entries by their checkpoint names; unmapped entries
      # keep their graph names. Iterate over a snapshot of the keys because
      # entries are popped while looping.
      update_map = {}
      for var_name in list(name2var):
        if var_name in name_map:
          update_map[name_map[var_name]] = name2var.pop(var_name)
      name2var.update(update_map)
      return name2var

    var_filter, scope_update = self.get_restore_filter()
    if var_filter is not None:
      # keys of name2var are variable names (str); filter on the key and
      # keep the associated variable (or list of partitions) as the value
      name2var = {
          var_name: var
          for var_name, var in name2var.items()
          if var_filter.keep(var_name)
      }
    # drop scope prefix if necessary
    if scope_update is not None:
      name2var = {
          scope_update(var_name): var for var_name, var in name2var.items()
      }
    return name2var