in nmt/model_helper.py [0:0]
def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
                 residual_connection=False, device_str=None, residual_fn=None):
  """Build one RNN cell and apply the optional wrappers around it.

  Progress is reported through `utils.print_out(..., new_line=False)` as each
  piece (base cell, then each wrapper) is created.

  Args:
    unit_type: "lstm", "gru", "layer_norm_lstm" or "nas"; selects the base
      cell class.
    num_units: passed as the first argument to the base cell constructor.
    forget_bias: forwarded to the LSTM variants ("lstm", "layer_norm_lstm");
      not used by the other cell types.
    dropout: dropout rate (= 1 - keep_prob) applied to the cell inputs via a
      DropoutWrapper. Replaced with 0.0 unless `mode` equals
      tf.contrib.learn.ModeKeys.TRAIN, and no wrapper is added when it is
      not > 0.0.
    mode: compared against tf.contrib.learn.ModeKeys.TRAIN.
    residual_connection: when truthy, wrap the cell in a ResidualWrapper.
    device_str: when truthy, wrap the cell in a DeviceWrapper for it.
    residual_fn: forwarded to ResidualWrapper as `residual_fn`.

  Returns:
    The base cell, possibly wrapped (inner to outer) by dropout, residual and
    device wrappers.

  Raises:
    ValueError: if `unit_type` matches none of the supported names.
  """
  # Dropout is a training-time regularizer only; eval/infer see no dropout.
  is_training = mode == tf.contrib.learn.ModeKeys.TRAIN
  if not is_training:
    dropout = 0.0

  # One builder per supported unit type. Each logs its description first and
  # then constructs the cell, so nothing is printed for an unknown type.
  def _lstm():
    utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False)
    return tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=forget_bias)

  def _gru():
    utils.print_out(" GRU", new_line=False)
    return tf.contrib.rnn.GRUCell(num_units)

  def _layer_norm_lstm():
    utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias,
                    new_line=False)
    return tf.contrib.rnn.LayerNormBasicLSTMCell(
        num_units, forget_bias=forget_bias, layer_norm=True)

  def _nas():
    utils.print_out(" NASCell", new_line=False)
    return tf.contrib.rnn.NASCell(num_units)

  # Ordered scan with `==` (rather than a dict lookup) so any value of
  # `unit_type`, hashable or not, is compared exactly as a plain if/elif would.
  candidates = (
      ("lstm", _lstm),
      ("gru", _gru),
      ("layer_norm_lstm", _layer_norm_lstm),
      ("nas", _nas),
  )
  for name, build in candidates:
    if unit_type == name:
      cell = build()
      break
  else:
    raise ValueError("Unknown unit type %s!" % unit_type)

  # Input dropout wrapper; keep_prob is the complement of the dropout rate.
  if dropout > 0.0:
    keep_prob = 1.0 - dropout
    cell = tf.contrib.rnn.DropoutWrapper(cell=cell, input_keep_prob=keep_prob)
    wrapper_name = type(cell).__name__
    utils.print_out(" %s, dropout=%g " % (wrapper_name, dropout),
                    new_line=False)

  # Residual wrapper, optionally with a caller-supplied combining function.
  if residual_connection:
    cell = tf.contrib.rnn.ResidualWrapper(cell, residual_fn=residual_fn)
    wrapper_name = type(cell).__name__
    utils.print_out(" %s" % wrapper_name, new_line=False)

  # Outermost: pin the cell's computation to the requested device.
  if device_str:
    cell = tf.contrib.rnn.DeviceWrapper(cell, device_str)
    wrapper_name = type(cell).__name__
    utils.print_out(" %s, device=%s" % (wrapper_name, device_str),
                    new_line=False)

  return cell