in easy_rec/python/model/easy_rec_model.py [0:0]
def _get_restore_vars(self, ckpt_var_map_path):
    """Collect the graph variables to restore, keyed by checkpoint name.

    Every global variable is indexed by its name with the output suffix
    (e.g. ``:0``) removed. Partitioned variables (names containing
    ``/part_<N>`` and carrying ``_save_slice_info``) are grouped into a list
    under the name with the partition component removed.

    If ``ckpt_var_map_path`` is non-empty, the file it points to is read as a
    variable map and matching graph names are replaced by their checkpoint
    names. Otherwise the filter and scope updater returned by
    ``self.get_restore_filter()`` are applied.

    Args:
        ckpt_var_map_path: path to a variable map file, or '' / None to use
            the restore filter instead. Each non-blank line of the file holds
            two whitespace-separated tokens: the variable name in the graph
            and the variable name in the checkpoint. Malformed lines are
            logged and skipped.

    Returns:
        dict mapping the name to look up in the checkpoint to either a single
        variable or, for partitioned variables, a list of partition variables.
        If the map file does not exist, a warning is logged and the dict is
        returned without any renaming.

    Raises:
        AssertionError: if two non-partitioned variables (or a partitioned and
            a non-partitioned variable) normalize to the same name.
    """
    # global_variables (not trainable variables) is required because
    # variables such as moving_mean and moving_variance are usually not
    # trainable in detection models, yet must still be restored.
    partition_pattern = re.compile(r'/part_[0-9]+')
    # Strip the output index, including multi-digit ones such as ':10'.
    var_suffix_pattern = re.compile(r':[0-9]+$')

    name2var = {}
    for one_var in tf.global_variables():
        var_name = var_suffix_pattern.sub('', one_var.name)
        is_part = (partition_pattern.search(var_name) is not None and
                   one_var._save_slice_info is not None)
        if is_part:
            var_name = partition_pattern.sub('', var_name)
        if var_name in name2var:
            # Only partitions of one partitioned variable may share a name,
            # and the existing entry must then already be a list.
            assert is_part and isinstance(name2var[var_name], list), (
                'multiple vars: %s' % var_name)
            name2var[var_name].append(one_var)
        else:
            name2var[var_name] = [one_var] if is_part else one_var

    if ckpt_var_map_path:
        if not gfile.Exists(ckpt_var_map_path):
            logging.warning('%s not exist', ckpt_var_map_path)
            return name2var
        # Load "graph_name ckpt_name" pairs.
        name_map = {}
        with gfile.GFile(ckpt_var_map_path, 'r') as fin:
            for one_line in fin:
                line_tok = one_line.split()
                if not line_tok:
                    continue  # blank line
                if len(line_tok) != 2:
                    logging.warning('Failed to process: %s', one_line.strip())
                    continue
                name_map[line_tok[0]] = line_tok[1]
        # Collect renamed entries first so name2var is not mutated while it
        # is iterated; renamed entries then take precedence over any
        # unmapped entry that already uses the same checkpoint name.
        renamed = {
            name_map[var_name]: var
            for var_name, var in name2var.items()
            if var_name in name_map
        }
        for var_name in name_map:
            name2var.pop(var_name, None)
        name2var.update(renamed)
        return name2var

    var_filter, scope_update = self.get_restore_filter()
    if var_filter is not None:
        # Filter on the normalized graph name (the dict key); values may be
        # lists of partitions and so have no single `.name`.
        name2var = {
            var_name: var
            for var_name, var in name2var.items()
            if var_filter.keep(var_name)
        }
    # drop scope prefix if necessary
    if scope_update is not None:
        name2var = {
            scope_update(var_name): var for var_name, var in name2var.items()
        }
    return name2var