in tf-ner-poc/src/main/python/normalizer/normalizer.py [0:0]
# Imports used by main() (the helpers load_data, encode_chars, mini_batch,
# create_graph and write_mapping are assumed to be defined earlier in this file).
import os
import random
import sys
import zipfile
from math import floor
from tempfile import TemporaryDirectory

import numpy as np
import tensorflow as tf


def main():
    if len(sys.argv) != 4:
        print("Usage: normalizer.py train_file dev_file test_file")
        return

    checkpoints_path = "/tmp/model/checkpoints"
    source_train, target_train = load_data(sys.argv[1])
    source_dev, target_dev = load_data(sys.argv[2])
    source_test, target_test = load_data(sys.argv[3])

    # Build char-to-id vocabularies over all splits; NUL (chr(0)) is reserved
    # as the padding char with id 0.
    source_char_dict = encode_chars(source_train + source_dev + source_test)
    source_char_dict[chr(0)] = 0
    target_char_dict = encode_chars(target_train + target_dev + target_test)
    # chr(2) is STX (Start of Text) and chr(3) is ETX (End of Text); append them
    # to the target vocabulary as the decoder's start and end markers.
    target_char_dict[chr(2)] = len(target_char_dict)
    target_char_dict[chr(3)] = len(target_char_dict)
    target_dict_rev = {v: k for k, v in target_char_dict.items()}
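    # For illustration (assumed behavior of encode_chars): on data such as
    # ["1 Jan 2018", ...] it might yield {'1': 1, ' ': 2, 'J': 3, ...}, i.e. a
    # dense char-to-id mapping that leaves id 0 free for the padding char above.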

    batch_size = 20

    target_max_len = -1
    for token in (target_train + target_dev + target_test):
        target_max_len = max(target_max_len, len(token))
    # Increase the size by one for the termination char (ETX)
    target_max_len += 1
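
    # Judging by the placeholder names below, create_graph is assumed to build a
    # char-level encoder-decoder (seq2seq): the TRAIN variant also takes decoder
    # inputs and lengths (teacher forcing), while the EVAL variant only consumes
    # the source chars and decodes up to target_max_len steps on its own.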
    train_graph = tf.Graph()
    eval_graph = tf.Graph()

    with train_graph.as_default():
        t_encoder_char_ids_ph, t_encoder_lengths_ph, t_decoder_char_ids_ph, t_decoder_lengths, \
            t_adam_optimize, t_train_prediction, t_dec_out = \
            create_graph("TRAIN", batch_size, len(source_char_dict), target_max_len, len(target_char_dict))
        train_saver = tf.train.Saver()
        train_sess = tf.Session(graph=train_graph)
        train_sess.run(tf.global_variables_initializer())

    with eval_graph.as_default():
        e_encoder_char_ids_ph, e_encoder_lengths_ph, e_dec_out = \
            create_graph("EVAL", batch_size, len(source_char_dict), target_max_len, len(target_char_dict))
        eval_saver = tf.train.Saver()
    eval_sess = tf.Session(graph=eval_graph)
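
    # The two graphs never share variables directly; after each training epoch
    # the weights are handed from train_sess to eval_sess by saving a checkpoint
    # with train_saver and restoring it with eval_saver (see below).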

    for epoch in range(20):
        print("Epoch " + str(epoch))
        with train_graph.as_default():
            for batch_index in range(floor(len(source_train) / batch_size)):
                if batch_index > 0 and batch_index % 100 == 0:
                    print("batch_index " + str(batch_index))
                target_batch, target_length, source_batch, source_length = \
                    mini_batch(target_char_dict, target_train, source_char_dict, source_train, batch_size, batch_index)

                # Char dropout: replace a small fraction of the input chars with
                # the padding id (0) to make the encoder robust to noisy input.
                for i, j in np.ndindex(source_batch.shape):
                    if random.uniform(0, 1) <= 0.0005:
                        source_batch[i][j] = 0
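                # At a rate of 0.0005 this blanks roughly 1 char in 2000 on
                # average; it is applied only here at training time, never
                # during evaluation.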

                feed_dict = {t_encoder_lengths_ph: source_length, t_encoder_char_ids_ph: source_batch,
                             t_decoder_lengths: target_length, t_decoder_char_ids_ph: target_batch}
                train_sess.run(t_adam_optimize, feed_dict)

        # Save the train model, and restore it into the eval session
        checkpoint_path = train_saver.save(train_sess, checkpoints_path, global_step=epoch)
        eval_saver.restore(eval_sess, checkpoint_path)

        with eval_graph.as_default():
            count_correct = 0
            for batch_index in range(floor(len(source_dev) / batch_size)):
                target_batch, target_length, source_batch, source_length = \
                    mini_batch(target_char_dict, target_dev, source_char_dict, source_dev, batch_size, batch_index)

                # Reference strings for this batch (begin/end index into the
                # flat dev list, so begin must be scaled by batch_size)
                begin = batch_index * batch_size
                end = min(begin + batch_size, len(source_dev))
                target_strings = target_dev[begin:end]

                feed_dict = {e_encoder_lengths_ph: source_length, e_encoder_char_ids_ph: source_batch}
                result = eval_sess.run(e_dec_out, feed_dict)

                # Map the predicted char ids back to strings, skipping ETX,
                # which was appended last and so has id len(target_char_dict) - 1
                decoded_dates = []
                for coded_date in result:
                    date = ""
                    for char_id in coded_date:
                        if char_id != len(target_char_dict) - 1:
                            date += target_dict_rev[char_id]
                    decoded_dates.append(date)

                # Exact string match against the reference
                for i in range(len(target_strings)):
                    if target_strings[i] == decoded_dates[i]:
                        count_correct += 1

        # Note: only full batches are evaluated, so any trailing examples
        # (len(source_dev) % batch_size of them) count as incorrect here.
        print("Dev: " + str(count_correct / len(target_dev)))
    with TemporaryDirectory() as temp_dir:
        temp_model_dir = temp_dir + "/model"
        with eval_graph.as_default():
            builder = tf.saved_model.builder.SavedModelBuilder(temp_model_dir)
            builder.add_meta_graph_and_variables(eval_sess, [tf.saved_model.tag_constants.SERVING])
            builder.save()
        write_mapping(source_char_dict, temp_model_dir + '/source_char_dict.txt')
        write_mapping(target_char_dict, temp_model_dir + '/target_char_dict.txt')

        # Use a context manager so the archive is properly closed and flushed
        with zipfile.ZipFile("normalizer.zip", 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(temp_model_dir):
                for file in files:
                    model_file = os.path.join(root, file)
                    zipf.write(model_file, arcname=os.path.relpath(model_file, temp_model_dir))
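

# Example invocation (file names are placeholders):
#   python normalizer.py train.txt dev.txt test.txt
# main() is assumed to be invoked under an `if __name__ == "__main__":` guard
# elsewhere in the file, outside this excerpt.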