in nmt/model.py [0:0]
def _set_train_or_infer(self, res, reverse_target_vocab_table, hparams):
"""Set up training and inference."""
  if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
    self.train_loss = res[1]
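    # Total source + target tokens in the batch (used e.g. for
    # words-per-second logging).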
    self.word_count = tf.reduce_sum(
        self.iterator.source_sequence_length) + tf.reduce_sum(
            self.iterator.target_sequence_length)
  elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
    self.eval_loss = res[1]
  elif self.mode == tf.contrib.learn.ModeKeys.INFER:
    self.infer_logits, _, self.final_context_state, self.sample_id = res
    self.sample_words = reverse_target_vocab_table.lookup(
        tf.to_int64(self.sample_id))

  if self.mode != tf.contrib.learn.ModeKeys.INFER:
    ## Count the number of predicted words for computing ppl.
    self.predict_count = tf.reduce_sum(
        self.iterator.target_sequence_length)

  params = tf.trainable_variables()

  # Gradients and SGD update operation for training the model.
  # Arrange for the embedding vars to appear at the beginning.
  if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
    self.learning_rate = tf.constant(hparams.learning_rate)
    # warm-up
    self.learning_rate = self._get_learning_rate_warmup(hparams)
    # decay
    self.learning_rate = self._get_learning_rate_decay(hparams)
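    # NOTE: the warmup/decay helpers (defined elsewhere in this class)
    # rewrap self.learning_rate in a schedule driven by self.global_step.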

    # Optimizer
    if hparams.optimizer == "sgd":
      opt = tf.train.GradientDescentOptimizer(self.learning_rate)
    elif hparams.optimizer == "adam":
      opt = tf.train.AdamOptimizer(self.learning_rate)
    else:
      raise ValueError("Unknown optimizer type %s" % hparams.optimizer)

    # Gradients
    gradients = tf.gradients(
        self.train_loss,
        params,
        colocate_gradients_with_ops=hparams.colocate_gradients_with_ops)
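    # (colocate_gradients_with_ops=True places each gradient op on the same
    # device as its corresponding forward op, which helps when the graph is
    # spread across multiple GPUs.)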

    clipped_grads, grad_norm_summary, grad_norm = model_helper.gradient_clip(
        gradients, max_gradient_norm=hparams.max_gradient_norm)
    self.grad_norm_summary = grad_norm_summary
    self.grad_norm = grad_norm

    self.update = opt.apply_gradients(
        zip(clipped_grads, params), global_step=self.global_step)
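    # Passing global_step here makes apply_gradients increment it once per
    # training step.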

    # Summary
    self.train_summary = self._get_train_summary()
  elif self.mode == tf.contrib.learn.ModeKeys.INFER:
    self.infer_summary = self._get_infer_summary(hparams)

  # Print trainable variables
  utils.print_out("# Trainable variables")
  utils.print_out("Format: <name>, <shape>, <(soft) device placement>")
  for param in params:
    utils.print_out("  %s, %s, %s" % (param.name, str(param.get_shape()),
                                      param.op.device))
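
model_helper.gradient_clip is not shown in this excerpt. A minimal sketch of
what it plausibly does, assuming it is built on tf.clip_by_global_norm and
matches the (clipped_grads, grad_norm_summary, grad_norm) tuple consumed
above (the helper body and exact summary names are assumptions, not
confirmed by this excerpt):

import tensorflow as tf

def gradient_clip(gradients, max_gradient_norm):
  """Clip gradients by global norm; return (clipped, summaries, norm)."""
  # Rescale all gradients together so their global norm is at most
  # max_gradient_norm; gradient_norm is the pre-clipping global norm.
  clipped_gradients, gradient_norm = tf.clip_by_global_norm(
      gradients, max_gradient_norm)
  # Summaries corresponding to what the caller stores as grad_norm_summary.
  gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)]
  gradient_norm_summary.append(
      tf.summary.scalar("clipped_gradient",
                        tf.global_norm(clipped_gradients)))
  return clipped_gradients, gradient_norm_summary, gradient_norm

Global-norm clipping rescales the whole gradient vector rather than clipping
each tensor independently, so the update direction is preserved when the
norm exceeds max_gradient_norm.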