in easy_rec/python/model/easy_rec_model.py [0:0]
def restore(self,
            ckpt_path,
            include_global_step=False,
            ckpt_var_map_path='',
            force_restore_shape_compatible=False):
"""Restore variables from ckpt_path.
steps:
1. list the variables in graph that need to be restored
2. inspect checkpoint and find the variables that could restore from checkpoint
substitute scope names in case necessary
3. call tf.train.init_from_checkpoint to restore the variables
Args:
ckpt_path: checkpoint path to restore from
include_global_step: whether to restore global_step variable
ckpt_var_map_path: variable map from graph variables to variables in a checkpoint
each line consists of: variable name in graph variable name in ckpt
force_restore_shape_compatible: if variable shape is incompatible, clip or pad
variables in checkpoint, and then restore
Returns:
IncompatibleShapeRestoreHook if force_shape_compatible else None
"""
  name2var_map = self._get_restore_vars(ckpt_var_map_path)
  logging.info('start to restore from %s' % ckpt_path)
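
  # Step 2: inspect the checkpoint to find the variables (and shapes)
  # that are available to restore from.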
  ckpt_reader = tf.train.NewCheckpointReader(ckpt_path)
  ckpt_var2shape_map = ckpt_reader.get_variable_to_shape_map()
  if not include_global_step:
    ckpt_var2shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
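
  # vars_in_ckpt: name -> variable(s) restorable via init_from_checkpoint;
  # incompatible_shape_var_map: graph variable -> checkpoint-shaped temporary
  #   variable, reconciled later by IncompatibleShapeRestoreHook;
  # fail_restore_vars: variables not found in the checkpoint.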
  vars_in_ckpt = {}
  incompatible_shape_var_map = {}
  fail_restore_vars = []
  for variable_name, variable in sorted(name2var_map.items()):
    if variable_name in ckpt_var2shape_map:
      print('restore %s' % variable_name)
      ckpt_var_shape = ckpt_var2shape_map[variable_name]
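      # A list of variables denotes a partitioned variable: sum the part
      # shapes along axis 0 and wrap the parts in a PartitionedVariable so
      # the shape check below sees the full, unpartitioned shape.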
      if isinstance(variable, list):
        shape_arr = [x.get_shape() for x in variable]
        var_shape = list(shape_arr[0])
        for x in shape_arr[1:]:
          var_shape[0] += x[0]
        var_shape = tensor_shape.TensorShape(var_shape)
        variable = variables.PartitionedVariable(
            variable_name,
            var_shape,
            variable[0].dtype,
            variable,
            partitions=[len(variable)] + [1] * (len(var_shape) - 1))
      else:
        var_shape = variable.shape.as_list()
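      # Shapes match exactly, so the variable (or its parts) can be handed
      # to tf.train.init_from_checkpoint as is.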
      if ckpt_var_shape == var_shape:
        vars_in_ckpt[variable_name] = list(variable) if isinstance(
            variable, variables.PartitionedVariable) else variable
      elif len(ckpt_var_shape) == len(var_shape):
        if force_restore_shape_compatible:
          # create a variable compatible with the checkpoint to restore
          dtype = variable[0].dtype if isinstance(variable,
                                                  list) else variable.dtype
          with tf.variable_scope('incompatible_shape_restore'):
            tmp_var = tf.get_variable(
                name=variable_name + '_T_E_M_P',
                shape=ckpt_var_shape,
                trainable=False,
                # add to a special collection for easy reference
                # by tf.get_collection('T_E_M_P_RESTROE')
                collections=['T_E_M_P_RESTROE'],
                dtype=dtype)
          vars_in_ckpt[variable_name] = tmp_var
          incompatible_shape_var_map[variable] = tmp_var
          print('incompatible restore %s[%s, %s]' %
                (variable_name, str(var_shape), str(ckpt_var_shape)))
        else:
          logging.warning(
              'Variable [%s] is available in checkpoint, but has an '
              'incompatible shape with the model variable.', variable_name)
      else:
        logging.warning(
            'Variable [%s] is available in checkpoint, but has a '
            'different number of dimensions than the model variable.',
            variable_name)
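    # PAI-TF EmbeddingVariable: stored in the checkpoint as '<name>-keys'
    # (plus companion tensors), so it cannot go through init_from_checkpoint;
    # build a SaveableObject restore op and use it as the initializer.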
    elif 'EmbeddingVariable' in str(type(variable)):
      if '%s-keys' % variable_name not in ckpt_var2shape_map:
        continue
      print('restore embedding_variable %s' % variable_name)
      from tensorflow.python.training import saver
      names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
      saveable_objects = []
      for name, op in names_to_saveables.items():
        for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
          saveable_objects.append(s)
      init_op = saveable_objects[0].restore([ckpt_path], None)
      variable._initializer_op = init_op
    elif isinstance(variable, list) and 'EmbeddingVariable' in str(
        type(variable[0])):
      if '%s/part_0-keys' % variable_name not in ckpt_var2shape_map:
        continue
      print('restore partitioned embedding_variable %s' % variable_name)
      from tensorflow.python.training import saver
      for part_var in variable:
        names_to_saveables = saver.BaseSaverBuilder.OpListToDict([part_var])
        saveable_objects = []
        for name, op in names_to_saveables.items():
          for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
            saveable_objects.append(s)
        init_op = saveable_objects[0].restore([ckpt_path], None)
        part_var._initializer_op = init_op
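    # SOK DynamicVariable: load this worker's key/value embedding shard
    # from the checkpoint and assign it after the variable is initialized.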
    elif sok is not None and isinstance(variable, sok.DynamicVariable):
      print('restore dynamic_variable %s' % variable_name)
      keys, vals = load_embed_lib.load_kv_embed(
          task_index=hvd.rank(),
          task_num=hvd.size(),
          embed_dim=variable._dimension,
          var_name='embed-' + variable.name.replace('/', '__'),
          ckpt_path=ckpt_path)
      with ops.control_dependencies([variable._initializer_op]):
        variable._initializer_op = dynamic_variable_ops.dummy_var_assign(
            variable.handle, keys, vals)
    else:
      fail_restore_vars.append(variable_name)
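
  # Warn about variables missing from the checkpoint, except optimizer
  # slot variables (Momentum), which are expected to start fresh.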
  for variable_name in fail_restore_vars:
    if 'Momentum' not in variable_name:
      logging.warning('Variable [%s] is not available in checkpoint',
                      variable_name)
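
  # Step 3: register the shape-compatible variables to be overwritten
  # from the checkpoint at initialization time.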
  tf.train.init_from_checkpoint(ckpt_path, vars_in_ckpt)

  if force_restore_shape_compatible:
    return estimator_utils.IncompatibleShapeRestoreHook(
        incompatible_shape_var_map)
  else:
    return None
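
# A minimal usage sketch (hypothetical names): `model` is assumed to be an
# EasyRecModel built in the current graph, and `training_hooks` a list passed
# to the estimator/session runner. The hook is only non-None when
# force_restore_shape_compatible=True, and it must actually be run so the
# clip/pad assignments from the temporary variables take effect.
#
#   restore_hook = model.restore(
#       'path/to/model.ckpt-1000', force_restore_shape_compatible=True)
#   if restore_hook is not None:
#     training_hooks.append(restore_hook)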