in nmt/model.py [0:0]
def _set_params_initializer(self,
                            hparams,
                            mode,
                            iterator,
                            source_vocab_table,
                            target_vocab_table,
                            scope,
                            extra_args=None):
  """Copy configuration onto the model and set up shared graph state.

  Stores the iterator, mode and vocabulary tables on ``self`` and copies the
  size, layer and dtype settings out of ``hparams``. It then creates the
  global step variable and installs the weight initializer on the current
  variable scope. Finally it builds the embeddings via
  ``self.init_embeddings``.

  Args:
    hparams: Hyperparameter object. Read for vocab sizes, GPU count,
      time-major flag, char-encoding flag, sampled-softmax size, unit and
      layer counts, residual layer counts, random seed and initializer
      settings.
    mode: Run mode, stored as ``self.mode``.
    iterator: Input pipeline; must be an ``iterator_utils.BatchedInput``.
    source_vocab_table: Source vocabulary table, stored as
      ``self.src_vocab_table``.
    target_vocab_table: Target vocabulary table, stored as
      ``self.tgt_vocab_table``.
    scope: Passed through unchanged to ``self.init_embeddings``.
    extra_args: Optional object providing the ``single_cell_fn`` and
      ``encoder_emb_lookup_fn`` hooks for external customization.

  Raises:
    AssertionError: (only when assertions are enabled) if ``iterator`` is
      not a ``BatchedInput``, if ``hparams.use_char_encode`` is combined with
      ``time_major``, or if either layer count is falsy.
  """
  assert isinstance(iterator, iterator_utils.BatchedInput)

  # Input pipeline, run mode and vocabulary lookup tables.
  self.iterator = iterator
  self.mode = mode
  self.src_vocab_table = source_vocab_table
  self.tgt_vocab_table = target_vocab_table

  # Sizes and data layout, copied directly from hparams.
  self.src_vocab_size = hparams.src_vocab_size
  self.tgt_vocab_size = hparams.tgt_vocab_size
  self.num_gpus = hparams.num_gpus
  self.time_major = hparams.time_major
  if hparams.use_char_encode:
    assert not self.time_major, ("Can't use time major for"
                                 " char-level inputs.")

  self.dtype = tf.float32
  self.num_sampled_softmax = hparams.num_sampled_softmax

  # Optional hook letting external code supply its own cell builder.
  self.single_cell_fn = extra_args.single_cell_fn if extra_args else None

  self.num_units = hparams.num_units

  # Encoder/decoder depth; both counts must be truthy (non-zero).
  self.num_encoder_layers = hparams.num_encoder_layers
  self.num_decoder_layers = hparams.num_decoder_layers
  assert self.num_encoder_layers
  assert self.num_decoder_layers

  # A single shared ``num_residual_layers`` (kept for compatibility with
  # common_test_utils) takes precedence over the per-side settings.
  if hasattr(hparams, "num_residual_layers"):
    encoder_residual = decoder_residual = hparams.num_residual_layers
  else:
    encoder_residual = hparams.num_encoder_residual_layers
    decoder_residual = hparams.num_decoder_residual_layers
  self.num_encoder_residual_layers = encoder_residual
  self.num_decoder_residual_layers = decoder_residual

  # Batch size as a tensor: the number of entries in the source lengths.
  self.batch_size = tf.size(self.iterator.source_sequence_length)

  self.global_step = tf.Variable(0, trainable=False)

  # Install the weight initializer on the current variable scope, so that
  # variables created later through tf.get_variable in this scope without an
  # explicit initializer default to it.
  self.random_seed = hparams.random_seed
  weight_init = model_helper.get_initializer(
      hparams.init_op, self.random_seed, hparams.init_weight)
  tf.get_variable_scope().set_initializer(weight_init)

  # Encoder embedding lookup: a truthy caller-provided hook overrides the
  # default tf.nn.embedding_lookup.
  custom_lookup = extra_args.encoder_emb_lookup_fn if extra_args else None
  self.encoder_emb_lookup_fn = custom_lookup or tf.nn.embedding_lookup

  self.init_embeddings(hparams, scope)