in tf-ner-poc/src/main/python/namefinder/namefinder.py [0:0]
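# Imports this excerpt relies on; assumed to sit at the top of namefinder.py
# next to the module-local load_glove, NameFinder, get_chunks and
# write_mapping definitions (not shown here).
import os
import sys
import zipfile
from math import floor
from tempfile import TemporaryDirectory

import numpy as np
import tensorflow as tf

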
def main():
    if len(sys.argv) != 5:
        print("Usage: namefinder.py embedding_file train_file dev_file test_file")
        return
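
    # load_glove yields the token-to-id mapping, its inverse, the embedding
    # matrix and the embedding vector size.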
    word_dict, rev_word_dict, embeddings, vector_size = load_glove(sys.argv[1])
    name_finder = NameFinder(vector_size)
    sentences, labels, char_set = name_finder.load_data(word_dict, sys.argv[2])
    sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, sys.argv[3])
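
    # The character vocabulary is the union of all characters seen in the
    # train and dev data, mapped to contiguous integer ids.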
    char_dict = {k: v for v, k in enumerate(char_set | char_set_dev)}
    embedding_ph, token_ids_ph, char_ids_ph, word_lengths_ph, sequence_lengths_ph, \
        labels_ph, dropout_keep_prob, train_op = name_finder.create_graph(len(char_dict), embeddings)
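
    # allow_soft_placement lets ops without a kernel on the requested device
    # fall back to the CPU; log_device_placement logs where each op runs.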
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=True))
    best_f1 = 0.0
    no_improvement = 0

    with sess.as_default():
        init = tf.global_variables_initializer()
        sess.run(init, feed_dict={embedding_ph: embeddings})
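
        # Train for up to 100 epochs in mini-batches of 20 sentences,
        # feeding a dropout keep probability of 0.5 during training.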
        batch_size = 20
        for epoch in range(100):
            print("Epoch " + str(epoch))
            for batch_index in range(floor(len(sentences) / batch_size)):
                if batch_index % 200 == 0:
                    print("batch_index " + str(batch_index))
                # mini_batch returns token ids, char ids, word lengths, labels
                # and sentence lengths for one slice of the training data.
                sentences_batch, chars_batch, word_length_batch, labels_batch, lengths = \
                    name_finder.mini_batch(char_dict, sentences, labels, batch_size, batch_index)
                feed_dict = {token_ids_ph: sentences_batch, char_ids_ph: chars_batch,
                             word_lengths_ph: word_length_batch, sequence_lengths_ph: lengths,
                             labels_ph: labels_batch, dropout_keep_prob: 0.5}
                train_op.run(feed_dict=feed_dict, session=sess)
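
            # After each epoch, evaluate on the dev set: token-level accuracy
            # plus entity-level (chunk) precision, recall and F1 via get_chunks.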
            accs = []
            correct_preds, total_correct, total_preds = 0., 0., 0.
            for batch_index in range(floor(len(sentences_dev) / batch_size)):
                sentences_dev_batch, chars_dev_batch, word_length_dev_batch, \
                    labels_dev_batch, lengths_dev = name_finder.mini_batch(
                        char_dict, sentences_dev, labels_dev, batch_size, batch_index)
                labels_pred, sequence_lengths = name_finder.predict_batch(
                    sess, token_ids_ph, char_ids_ph, word_lengths_ph, sequence_lengths_ph,
                    sentences_dev_batch, chars_dev_batch, word_length_dev_batch, lengths_dev,
                    dropout_keep_prob)
                for lab, lab_pred, length in zip(labels_dev_batch, labels_pred,
                                                 sequence_lengths):
                    lab = lab[:length]
                    lab_pred = lab_pred[:length]
                    accs += [a == b for (a, b) in zip(lab, lab_pred)]
                    lab_chunks = set(get_chunks(lab, name_finder.label_dict))
                    lab_pred_chunks = set(get_chunks(lab_pred, name_finder.label_dict))
                    correct_preds += len(lab_chunks & lab_pred_chunks)
                    total_preds += len(lab_pred_chunks)
                    total_correct += len(lab_chunks)
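
            # Precision: correct chunks / predicted chunks; recall: correct
            # chunks / gold chunks; F1 is their harmonic mean.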
            p = correct_preds / total_preds if correct_preds > 0 else 0
            r = correct_preds / total_correct if correct_preds > 0 else 0
            f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
            acc = np.mean(accs)
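
            # On a new best dev F1, export a SavedModel together with the
            # word, label and char mappings, zipped as namefinder-<epoch>.zip.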
            if f1 > best_f1:
                best_f1 = f1
                no_improvement = 0
                with TemporaryDirectory() as temp_dir:
                    temp_model_dir = temp_dir + "/model"
                    builder = tf.saved_model.builder.SavedModelBuilder(temp_model_dir)
                    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
                    builder.save()
                    write_mapping(word_dict, temp_model_dir + "/word_dict.txt")
                    write_mapping(name_finder.label_dict, temp_model_dir + "/label_dict.txt")
                    write_mapping(char_dict, temp_model_dir + "/char_dict.txt")
                    # Close the archive deterministically so no entries are lost.
                    with zipfile.ZipFile("namefinder-" + str(epoch) + ".zip", 'w',
                                         zipfile.ZIP_DEFLATED) as zipf:
                        for root, dirs, files in os.walk(temp_model_dir):
                            for file in files:
                                model_file = os.path.join(root, file)
                                zipf.write(model_file,
                                           arcname=os.path.relpath(model_file, temp_model_dir))
            else:
                no_improvement += 1
            print("ACC " + str(acc))
            print("F1 " + str(f1) + " P " + str(p) + " R " + str(r))
            if no_improvement > 5:
                print("No further improvement. Stopping.")
                break
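

# Entry point: assumed to match the usual guard at the end of the script
# (the excerpt above shows only main()).
if __name__ == "__main__":
    main()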