# Training entry point for the TensorFlow NER proof of concept.
# Origin: tf-ner-poc/src/main/python/namefinder/namefinder.py


def main():
    """Train the NameFinder NER model and export the best checkpoint as a zip.

    Command-line arguments (sys.argv):
        1: embedding_file -- GloVe-style embeddings used to initialise the
           token embedding matrix.
        2: train_file     -- labelled training sentences.
        3: dev_file       -- development set used for model selection and
           early stopping.
        4: test_file      -- required on the command line but not read here.

    Side effects: writes ``namefinder-<epoch>.zip`` into the working directory
    each time the dev-set F1 improves.
    """
    if len(sys.argv) != 5:
        print("Usage namefinder.py embedding_file train_file dev_file test_file")
        return

    word_dict, rev_word_dict, embeddings, vector_size = load_glove(sys.argv[1])

    name_finder = NameFinder(vector_size)

    sentences, labels, char_set = name_finder.load_data(word_dict, sys.argv[2])
    sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, sys.argv[3])

    # Build the character vocabulary over train + dev so the char embedding
    # table covers every character the evaluation loop will feed in.
    char_dict = {k: v for v, k in enumerate(char_set | char_set_dev)}

    embedding_ph, token_ids_ph, char_ids_ph, word_lengths_ph, sequence_lengths_ph, labels_ph, dropout_keep_prob, train_op \
        = name_finder.create_graph(len(char_dict), embeddings)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=True))

    best_f1 = 0.0
    no_improvement = 0
    with sess.as_default():
        init = tf.global_variables_initializer()
        # Embeddings are fed through a placeholder at init time to avoid
        # baking the (large) embedding matrix into the graph definition.
        sess.run(init, feed_dict={embedding_ph: embeddings})

        batch_size = 20
        for epoch in range(100):
            print("Epoch " + str(epoch))

            # --- training pass ---------------------------------------------
            for batch_index in range(floor(len(sentences) / batch_size)):
                if batch_index % 200 == 0:
                    print("batch_index " + str(batch_index))

                # mini_batch should also return char_ids and word length ...
                sentences_batch, chars_batch, word_length_batch, labels_batch, lengths = \
                    name_finder.mini_batch(char_dict, sentences, labels, batch_size, batch_index)

                feed_dict = {token_ids_ph: sentences_batch, char_ids_ph: chars_batch,
                             word_lengths_ph: word_length_batch, sequence_lengths_ph: lengths,
                             labels_ph: labels_batch, dropout_keep_prob: 0.5}

                train_op.run(feed_dict, sess)

            # --- evaluation pass on the dev set ----------------------------
            accs = []
            correct_preds, total_correct, total_preds = 0., 0., 0.
            for batch_index in range(floor(len(sentences_dev) / batch_size)):
                sentences_test_batch, chars_batch_test, word_length_batch_test, \
                labels_test_batch, length_test = name_finder.mini_batch(char_dict,
                                                                        sentences_dev,
                                                                        labels_dev,
                                                                        batch_size,
                                                                        batch_index)

                labels_pred, sequence_lengths = name_finder.predict_batch(
                    sess, token_ids_ph, char_ids_ph, word_lengths_ph, sequence_lengths_ph,
                    sentences_test_batch, chars_batch_test, word_length_batch_test, length_test, dropout_keep_prob)

                for lab, lab_pred, length in zip(labels_test_batch, labels_pred,
                                                 sequence_lengths):
                    # Trim padding before scoring.
                    lab = lab[:length]
                    lab_pred = lab_pred[:length]
                    accs += [a == b for (a, b) in zip(lab, lab_pred)]

                    # Chunk-level (entity span) comparison for precision/recall.
                    lab_chunks = set(get_chunks(lab, name_finder.label_dict))
                    lab_pred_chunks = set(get_chunks(lab_pred, name_finder.label_dict))

                    correct_preds += len(lab_chunks & lab_pred_chunks)
                    total_preds += len(lab_pred_chunks)
                    total_correct += len(lab_chunks)

            # Guard against division by zero when nothing was predicted correctly.
            p = correct_preds / total_preds if correct_preds > 0 else 0
            r = correct_preds / total_correct if correct_preds > 0 else 0
            f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
            acc = np.mean(accs)

            if f1 > best_f1:

                best_f1 = f1
                no_improvement = 0

                # Export the improved model plus its vocabularies as a zip.
                with TemporaryDirectory() as temp_dir:
                    temp_model_dir = temp_dir + "/model"

                    builder = tf.saved_model.builder.SavedModelBuilder(temp_model_dir)
                    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
                    builder.save()

                    write_mapping(word_dict, temp_model_dir + '/word_dict.txt')
                    write_mapping(name_finder.label_dict, temp_model_dir + "/label_dict.txt")
                    write_mapping(char_dict, temp_model_dir + "/char_dict.txt")

                    # Use a context manager so the archive is always finalised:
                    # without close() the zip central directory is never
                    # written and the file is left corrupt (and the handle
                    # leaked once per improving epoch).
                    with zipfile.ZipFile("namefinder-" + str(epoch) + ".zip", 'w', zipfile.ZIP_DEFLATED) as zipf:
                        for root, dirs, files in os.walk(temp_model_dir):
                            for file in files:
                                modelFile = os.path.join(root, file)
                                zipf.write(modelFile, arcname=os.path.relpath(modelFile, temp_model_dir))
            else:
                no_improvement += 1

            print("ACC " + str(acc))
            print("F1  " + str(f1) + "  P " + str(p) + "  R " + str(r))

            # Early stopping: give up after 6 consecutive epochs without
            # a dev-set F1 improvement.
            if no_improvement > 5:
                print("No further improvement. Stopping.")
                break