# def main()
#
# in tf-ner-poc/src/main/python/namecat/namecat.py [0:0]


def _build_char_dict(names):
    """Return a dict mapping each character found in *names* to an integer id.

    Id 0 is reserved for ``chr(0)``; the training loop in ``main`` overwrites
    dropped characters with 0, so no real character may share that id.
    Real characters are numbered from 1 in sorted order, which keeps the
    mapping deterministic across runs.
    """
    chars = set()
    for name in names:
        chars.update(name)
    # chr(0) always gets the reserved id below, even if it occurs in the data.
    chars.discard(chr(0))

    char_dict = {chr(0): 0}
    for char_id, char in enumerate(sorted(chars), start=1):
        char_dict[char] = char_id
    return char_dict


def _batch_accuracy(labels, probs):
    """Return the fraction of rows in *probs* whose argmax equals the label."""
    predictions = np.argmax(probs, axis=1)
    return np.mean(np.asarray(labels) == predictions)


def _zip_directory(source_dir, zip_path):
    """Write every file below *source_dir* into a deflated zip at *zip_path*.

    Archive member names are relative to *source_dir*. The archive is closed
    on exit so that its central directory is always written.
    """
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for root, _dirs, files in os.walk(source_dir):
            for file_name in files:
                file_path = os.path.join(root, file_name)
                archive.write(file_path,
                              arcname=os.path.relpath(file_path, source_dir))


def main():
    """Train the name categorizer and export it as a zipped SavedModel.

    Expects exactly three command line arguments: the train, dev and test
    files. If the argument count is wrong, a usage message is printed and the
    function returns without doing anything else.

    The model is trained for 20 epochs with mini batches of 20 names; train
    and dev accuracy are printed after every epoch. Afterwards the graph,
    its variables and the label/char mappings are written to
    ``namecat-<last epoch index>.zip`` in the current working directory.
    """
    if len(sys.argv) != 4:
        print("Usage namecat.py train_file dev_file test_file")
        return

    labels_train, names_train = load_data(sys.argv[1])
    labels_dev, names_dev = load_data(sys.argv[2])
    # The test labels are not used yet (see the TODO after the training loop).
    labels_test, names_test = load_data(sys.argv[3])

    # Encode labels into ids, in order of first appearance in the train data.
    label_dict = {}
    for label in labels_train:
        label_dict.setdefault(label, len(label_dict))

    # Build the char vocabulary from all three splits so that every character
    # that can be looked up later has an id.
    char_dict = _build_char_dict(names_train + names_dev + names_test)

    dropout_keep_prob, char_ids_ph, name_lengths_ph, y_ph = create_placeholders()

    # len(char_dict) includes the reserved id 0, so it is an exclusive upper
    # bound on every char id.
    train_op, probs_op = create_graph(dropout_keep_prob, char_ids_ph, name_lengths_ph, y_ph,
                                      len(char_dict), len(label_dict))

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=True)

    # Using the session itself as the context manager installs it as the
    # default session and closes it when the block is left.
    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())

        batch_size = 20
        num_epochs = 20
        char_dropout_prob = 0.0005

        for epoch in range(num_epochs):
            print("Epoch " + str(epoch))
            acc_train = []

            # Shuffle names and labels together, in place.
            combined = list(zip(names_train, labels_train))
            random.shuffle(combined)
            names_train[:], labels_train[:] = zip(*combined)

            # Only full batches are used; a trailing partial batch is skipped.
            for batch_index in range(len(names_train) // batch_size):
                label_train_batch, name_train_batch, name_train_length = \
                    mini_batch(label_dict, char_dict, labels_train, names_train, batch_size, batch_index)

                # Char dropout: occasionally replace a char id with the
                # reserved id 0.
                for i, j in np.ndindex(name_train_batch.shape):
                    if random.uniform(0, 1) <= char_dropout_prob:
                        name_train_batch[i, j] = 0

                feed_dict = {dropout_keep_prob: 0.5,
                             char_ids_ph: name_train_batch,
                             name_lengths_ph: name_train_length,
                             y_ph: label_train_batch}
                _, probs = sess.run([train_op, probs_op], feed_dict)

                acc_train.append(_batch_accuracy(label_train_batch, probs))

            print("Train acc: " + str(np.mean(acc_train)))

            acc_dev = []
            for batch_index in range(len(names_dev) // batch_size):
                label_dev_batch, name_dev_batch, name_dev_length = \
                    mini_batch(label_dict, char_dict, labels_dev, names_dev, batch_size, batch_index)

                # Keep probability 1 disables dropout for evaluation.
                feed_dict = {dropout_keep_prob: 1,
                             char_ids_ph: name_dev_batch,
                             name_lengths_ph: name_dev_length,
                             y_ph: label_dev_batch}
                probs = sess.run(probs_op, feed_dict)

                acc_dev.append(_batch_accuracy(label_dev_batch, probs))

            print("Dev acc: " + str(np.mean(acc_dev)))

        # TODO: evaluation on the test split is not implemented; the test
        # names currently only contribute to the char vocabulary. When adding
        # it, feed dropout_keep_prob: 1 and reuse _batch_accuracy as in the
        # dev loop above.

        with TemporaryDirectory() as temp_dir:
            temp_model_dir = os.path.join(temp_dir, "model")

            builder = tf.saved_model.builder.SavedModelBuilder(temp_model_dir)
            builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
            builder.save()

            write_mapping(label_dict, os.path.join(temp_model_dir, "label_dict.txt"))
            write_mapping(char_dict, os.path.join(temp_model_dir, "char_dict.txt"))

            # `epoch` still holds the index of the last completed epoch.
            _zip_directory(temp_model_dir, "namecat-" + str(epoch) + ".zip")