def create_emb_for_encoder_and_decoder()

in nmt/model_helper.py


def create_emb_for_encoder_and_decoder(share_vocab,
                                       src_vocab_size,
                                       tgt_vocab_size,
                                       src_embed_size,
                                       tgt_embed_size,
                                       dtype=tf.float32,
                                       num_enc_partitions=0,
                                       num_dec_partitions=0,
                                       src_vocab_file=None,
                                       tgt_vocab_file=None,
                                       src_embed_file=None,
                                       tgt_embed_file=None,
                                       use_char_encode=False,
                                       scope=None):
  """Create embedding matrix for both encoder and decoder.

  Args:
    share_vocab: A boolean. Whether to share embedding matrix for both
      encoder and decoder.
    src_vocab_size: An integer. The source vocab size.
    tgt_vocab_size: An integer. The target vocab size.
    src_embed_size: An integer. The embedding dimension for the encoder's
      embedding.
    tgt_embed_size: An integer. The embedding dimension for the decoder's
      embedding.
    dtype: dtype of the embedding matrix. Default to float32.
    num_enc_partitions: number of partitions used for the encoder's embedding
      vars.
    num_dec_partitions: number of partitions used for the decoder's embedding
      vars.
    src_vocab_file: A string. The source vocabulary file.
    tgt_vocab_file: A string. The target vocabulary file.
    src_embed_file: A string. Optional file of pretrained embeddings for the
      encoder.
    tgt_embed_file: A string. Optional file of pretrained embeddings for the
      decoder.
    use_char_encode: A boolean. If True (and share_vocab is False), no encoder
      embedding is created and embedding_encoder is returned as None.
    scope: VariableScope for the created subgraph. Default to "embeddings".

  Returns:
    embedding_encoder: Encoder's embedding matrix.
    embedding_decoder: Decoder's embedding matrix.

  Raises:
    ValueError: if share_vocab is True but the source and target vocab sizes
      differ.
  """
  if num_enc_partitions <= 1:
    enc_partitioner = None
  else:
    # Note: num_partitions > 1 is required for distributed training because
    # embedding_lookup tries to colocate a single-partition embedding variable
    # with the lookup ops, which may cause embedding variables to be placed on
    # worker jobs.
    enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions)

  if num_dec_partitions <= 1:
    dec_partitioner = None
  else:
    # Note: num_partitions > 1 is required for distributed training because
    # embedding_lookup tries to colocate a single-partition embedding variable
    # with the lookup ops, which may cause embedding variables to be placed on
    # worker jobs.
    dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions)

  if src_embed_file and enc_partitioner:
    raise ValueError(
        "Can't set num_enc_partitions > 1 when using pretrained encoder "
        "embedding")

  if tgt_embed_file and dec_partitioner:
    raise ValueError(
        "Can't set num_dec_partitions > 1 when using pretrained decdoer "
        "embedding")

  with tf.variable_scope(
      scope or "embeddings", dtype=dtype, partitioner=enc_partitioner) as scope:
    # Share embedding
    if share_vocab:
      if src_vocab_size != tgt_vocab_size:
        raise ValueError("Share embedding but different src/tgt vocab sizes"
                         " %d vs. %d" % (src_vocab_size, tgt_vocab_size))
      assert src_embed_size == tgt_embed_size
      utils.print_out("# Use the same embedding for source and target")
      vocab_file = src_vocab_file or tgt_vocab_file
      embed_file = src_embed_file or tgt_embed_file

      embedding_encoder = _create_or_load_embed(
          "embedding_share", vocab_file, embed_file,
          src_vocab_size, src_embed_size, dtype)
      embedding_decoder = embedding_encoder
    else:
      if not use_char_encode:
        with tf.variable_scope("encoder", partitioner=enc_partitioner):
          embedding_encoder = _create_or_load_embed(
              "embedding_encoder", src_vocab_file, src_embed_file,
              src_vocab_size, src_embed_size, dtype)
      else:
        embedding_encoder = None

      with tf.variable_scope("decoder", partitioner=dec_partitioner):
        embedding_decoder = _create_or_load_embed(
            "embedding_decoder", tgt_vocab_file, tgt_embed_file,
            tgt_vocab_size, tgt_embed_size, dtype)

  return embedding_encoder, embedding_decoder
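
A minimal usage sketch, not taken from the original repo: it assumes TF 1.x graph
mode and that model_helper is importable as shown; the vocab sizes and embedding
dimensions are illustrative only.

import tensorflow as tf

from nmt import model_helper  # assumed import path, per nmt/model_helper.py

tf.reset_default_graph()

# Separate source/target embeddings, no partitioning, no pretrained files.
# (Illustrative sizes; real values come from the vocab files and hparams.)
embedding_encoder, embedding_decoder = (
    model_helper.create_emb_for_encoder_and_decoder(
        share_vocab=False,
        src_vocab_size=32000,
        tgt_vocab_size=32000,
        src_embed_size=512,
        tgt_embed_size=512,
        scope="embeddings"))

# The returned matrices are typically consumed with tf.nn.embedding_lookup.
source_ids = tf.placeholder(tf.int32, shape=[None, None])  # [batch, time]
encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source_ids)

With share_vocab=True (and equal src/tgt vocab and embedding sizes), both
returned values refer to the same "embedding_share" variable.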