in tf-ner-poc/src/main/python/normalizer/normalizer.py [0:0]
import tensorflow as tf


def create_graph(mode, batch_size, encoder_nchars, max_target_length, decoder_nchars):
    # Hyper parameters
    encoder_char_dim = 100
    num_units = 256

    # The batch size can be overridden at run time (e.g. at inference) via this placeholder
    batch_size_ph = tf.placeholder_with_default(batch_size, shape=(), name="batch_size")
    # Encoder
    encoder_char_ids_ph = tf.placeholder(tf.int32, shape=[None, None], name="encoder_char_ids")
    encoder_lengths_ph = tf.placeholder(tf.int32, shape=[None], name="encoder_lengths")

    encoder_embedding_weights = tf.get_variable(name="char_embeddings", dtype=tf.float32,
                                                shape=[encoder_nchars, encoder_char_dim])
    encoder_emb_inp = tf.nn.embedding_lookup(encoder_embedding_weights, encoder_char_ids_ph)

    if mode == "TRAIN":
        # Dropout on the character embeddings during training (keep probability 0.7)
        encoder_emb_inp = tf.nn.dropout(encoder_emb_inp, keep_prob=0.7)

    # Switch to time-major [max_time, batch_size, char_dim] for the time-major RNN below
    encoder_emb_inp = tf.transpose(encoder_emb_inp, perm=[1, 0, 2])

    encoder_cell = tf.nn.rnn_cell.LSTMCell(num_units)
    initial_state = encoder_cell.zero_state(batch_size_ph, dtype=tf.float32)

    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
        encoder_cell, encoder_emb_inp, initial_state=initial_state,
        sequence_length=encoder_lengths_ph,
        time_major=True, swap_memory=True)
    # Decoder
    decoder_char_ids_ph = tf.placeholder(tf.int32, shape=[None, None], name="decoder_char_ids")
    decoder_lengths = tf.placeholder(tf.int32, shape=[None], name="decoder_lengths")

    # The decoder input is the target prefixed with a start-of-sequence char id
    # (decoder_nchars - 2); the decoder output is the target suffixed with an
    # end-of-sequence char id (decoder_nchars - 1), i.e. the decoder input
    # shifted to the left by one. For a target "abc":
    #   input:  <SOS> a b c
    #   output: a b c <EOS>
    decoder_char_dim = 100
    decoder_embedding_weights = tf.get_variable(name="decoder_char_embeddings", dtype=tf.float32,
                                                shape=[decoder_nchars, decoder_char_dim])

    # Projects the decoder output at each step to logits over the char vocabulary
    projection_layer = tf.layers.Dense(units=decoder_nchars, use_bias=True)
    # Luong-style attention over the encoder outputs (expects batch-major memory)
    attention_states = tf.transpose(encoder_outputs, [1, 0, 2])
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units, attention_states,
        memory_sequence_length=encoder_lengths_ph)

    decoder_cell = tf.nn.rnn_cell.LSTMCell(num_units)
    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism,
                                                       attention_layer_size=num_units)

    # To initialize from the encoder's final state instead, use
    # decoder_cell.zero_state(...).clone(cell_state=encoder_state)
    decoder_initial_state = decoder_cell.zero_state(dtype=tf.float32, batch_size=batch_size_ph)
if "TRAIN" == mode:
decoder_input = tf.pad(decoder_char_ids_ph, tf.constant([[0,0], [1,0]]),
'CONSTANT', constant_values=(decoder_nchars-2))
decoder_emb_inp = tf.nn.embedding_lookup(decoder_embedding_weights, decoder_input)
decoder_emb_inp = tf.transpose(decoder_emb_inp, perm=[1, 0, 2])
helper = tf.contrib.seq2seq.TrainingHelper(
decoder_emb_inp, [max_target_length for _ in range(batch_size)], time_major=True)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
decoder_initial_state, output_layer=projection_layer)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=True, swap_memory=True )
logits = outputs.rnn_output
train_prediction = outputs.sample_id
decoder_output = tf.pad(tf.transpose(decoder_char_ids_ph, perm=[1, 0]), tf.constant([[0,1], [0,0]]),
'CONSTANT', constant_values=(decoder_nchars-1))
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=decoder_output, logits=logits, name="crossent")
loss = tf.reduce_sum(crossent * tf.to_float(decoder_lengths)) / (batch_size * max_target_length)
# Optimizer
# TODO: Tutorial suggest to swap to SGD for alter iterations
# optimizer = tf.train.AdamOptimizer()
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, 10.0)
optimize = optimizer.apply_gradients(zip(gradients, v))
return encoder_char_ids_ph, encoder_lengths_ph, decoder_char_ids_ph, decoder_lengths, optimize, train_prediction, outputs
if "EVAL" == mode:
helperE = tf.contrib.seq2seq.GreedyEmbeddingHelper(
decoder_embedding_weights,
tf.fill([batch_size_ph], decoder_nchars-2), decoder_nchars-1)
decoderE = tf.contrib.seq2seq.BasicDecoder(
decoder_cell, helperE, decoder_initial_state,
output_layer=projection_layer)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoderE, maximum_iterations=20)
translations = tf.identity(outputs.sample_id, name="decode")
return encoder_char_ids_ph, encoder_lengths_ph, translations
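
# --- Usage sketch (not part of the original file) ---
# A minimal sketch of how this graph might be driven in TRAIN mode, assuming a
# hypothetical load_data() that yields padded id matrices plus true lengths.
# Targets are assumed padded to max_target_length - 1 steps, so that the
# SOS-prefixed input and EOS-suffixed output both span max_target_length steps
# and line up with the logits produced by the TrainingHelper.
def train_sketch():
    batch_size, max_target_length = 32, 20
    encoder_nchars, decoder_nchars = 100, 102  # the two highest decoder ids serve as SOS/EOS

    with tf.Graph().as_default():
        (encoder_char_ids_ph, encoder_lengths_ph, decoder_char_ids_ph,
         decoder_lengths, optimize, train_prediction, _) = create_graph(
             "TRAIN", batch_size, encoder_nchars, max_target_length, decoder_nchars)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for src_ids, src_lens, tgt_ids, tgt_lens in load_data():  # hypothetical batcher
                sess.run(optimize, feed_dict={
                    encoder_char_ids_ph: src_ids,   # [batch, src_time]
                    encoder_lengths_ph: src_lens,   # [batch]
                    decoder_char_ids_ph: tgt_ids,   # [batch, max_target_length - 1]
                    decoder_lengths: tgt_lens})     # [batch]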