in nmt/model_helper.py [0:0]
def avg_checkpoints(model_dir, num_last_checkpoints, global_step,
                    global_step_name):
    """Average the last N checkpoints in the model_dir.

    Every variable listed in the oldest selected checkpoint, except the one
    named `global_step_name`, is averaged element-wise across the selected
    checkpoints. The result is written to an "avg_checkpoints" sub-directory
    of `model_dir` (created if missing).

    Args:
      model_dir: directory whose checkpoint state lists the checkpoints.
      num_last_checkpoints: number of most recent checkpoints to average.
      global_step: initial value given to the global-step variable stored in
        the averaged checkpoint.
      global_step_name: name of the global-step variable; it is excluded from
        averaging and re-created from `global_step`.

    Returns:
      The path of the directory holding the averaged checkpoint, or None when
      no checkpoint state is found or fewer than `num_last_checkpoints`
      checkpoints are available.
    """
    checkpoint_state = tf.train.get_checkpoint_state(model_dir)
    if not checkpoint_state:
        utils.print_out(
            "# No checkpoint file found in directory: %s" % model_dir)
        return None

    # Checkpoints are ordered from oldest to newest.
    checkpoints = (
        checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])
    # An empty selection is rejected as well: indexing checkpoints[0] below
    # would otherwise raise IndexError.
    if not checkpoints or len(checkpoints) < num_last_checkpoints:
        utils.print_out(
            "# Skipping averaging checkpoints because not enough "
            "checkpoints is avaliable."
        )
        return None

    avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
    if not tf.gfile.Exists(avg_model_dir):
        utils.print_out(
            "# Creating new directory %s for saving averaged checkpoints." %
            avg_model_dir)
        tf.gfile.MakeDirs(avg_model_dir)

    utils.print_out("# Reading and averaging variables in checkpoints:")
    # The variable set is taken from the oldest selected checkpoint; the
    # global step is not averaged.
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    var_values, var_dtypes = {}, {}
    for (name, shape) in var_list:
        if name != global_step_name:
            var_values[name] = np.zeros(shape)

    # Accumulate the sum of every variable over the selected checkpoints and
    # remember each variable's stored dtype.
    for checkpoint in checkpoints:
        utils.print_out(" %s" % checkpoint)
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor

    for name in var_values:
        var_values[name] /= len(checkpoints)

    # Build a graph with same variables in the checkpoints, and save the
    # averaged variables into the avg_model_dir.
    with tf.Graph().as_default():
        # Fix one iteration order so variables, placeholders, assign ops and
        # values stay aligned.
        var_names = list(var_values)
        # Each variable is created with its own dtype (not the dtype of
        # whichever variable happened to be read last).
        tf_vars = [
            tf.get_variable(
                var_name,
                shape=var_values[var_name].shape,
                dtype=var_dtypes[var_name])
            for var_name in var_names
        ]
        placeholders = [
            tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
        assign_ops = [
            tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
        # Created only so that it is part of the graph and therefore saved
        # together with the averaged variables.
        tf.Variable(global_step, name=global_step_name, trainable=False)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for var_name, placeholder, assign_op in zip(
                    var_names, placeholders, assign_ops):
                sess.run(assign_op, {placeholder: var_values[var_name]})

            # The save path is fixed and carries no step suffix, so a later
            # call writes to the same checkpoint prefix.
            saver.save(
                sess,
                os.path.join(avg_model_dir, "translate.ckpt"))

    return avg_model_dir