def main()

in tf-ner-poc/src/main/python/normalizer/normalizer.py [0:0]

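For reference, this listing depends on the following imports and local helpers; the import forms below are inferred from usage in the code, not confirmed against the full file:

import os
import random
import sys
import zipfile
from math import floor
from tempfile import TemporaryDirectory

import numpy as np
import tensorflow as tf

# local helpers defined elsewhere in normalizer.py:
# load_data, encode_chars, create_graph, mini_batch, write_mapping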

def main():

    if len(sys.argv) != 4:
        print("Usage: normalizer.py train_file dev_file test_file")
        return

    checkpoints_path = "/tmp/model/checkpoints"

    source_train, target_train = load_data(sys.argv[1])
    source_dev, target_dev = load_data(sys.argv[2])
    source_test, target_test = load_data(sys.argv[3])

    source_char_dict = encode_chars(source_train + source_dev + source_test)
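    # Reserve NUL (chr(0)) as id 0; the char dropout in the training loop
    # below maps dropped characters to this id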
    source_char_dict[chr(0)] = 0

    target_char_dict = encode_chars(target_train + target_dev + target_test)

    # chr(2) is STX (Start of Text) and chr(3) is ETX (End of Text); they are
    # appended to the target dict as the decoder's sequence delimiters
    target_char_dict[chr(2)] = len(target_char_dict)
    target_char_dict[chr(3)] = len(target_char_dict)

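    # Reverse lookup for decoding predicted char ids back into characters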
    target_dict_rev = {v: k for k, v in target_char_dict.items()}

    batch_size = 20

    # Longest target token, plus one for the ETX termination char
    target_max_len = max(len(token) for token in
                         target_train + target_dev + target_test) + 1

    train_graph = tf.Graph()
    eval_graph = tf.Graph()

    with train_graph.as_default():
        t_encoder_char_ids_ph, t_encoder_lengths_ph, t_decoder_char_ids_ph, t_decoder_lengths, t_adam_optimize, t_train_prediction, t_dec_out = \
            create_graph("TRAIN", batch_size, len(source_char_dict), target_max_len, len(target_char_dict))
        train_saver = tf.train.Saver()
        train_sess = tf.Session()
        train_sess.run(tf.global_variables_initializer())

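    # Separate inference graph: same model but without decoder inputs; it
    # receives the trained weights via checkpoint save/restore below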
    with eval_graph.as_default():
        e_encoder_char_ids_ph, e_encoder_lengths_ph, e_dec_out = \
            create_graph("EVAL", batch_size, len(source_char_dict), target_max_len, len(target_char_dict))
        eval_saver = tf.train.Saver()

        eval_sess = tf.Session(graph=eval_graph)

    for epoch in range(20):
        print("Epoch " + str(epoch))

        with train_graph.as_default():
            for batch_index in range(floor(len(source_train) / batch_size)):
                if batch_index > 0 and batch_index % 100 == 0:
                    print("batch_index " + str(batch_index))

                target_batch, target_length, source_batch, source_length = \
                    mini_batch(target_char_dict, target_train, source_char_dict, source_train, batch_size, batch_index)

                # Char dropout: randomly replace a small fraction of input
                # chars with id 0 to make the encoder robust to noisy input
                for i, j in np.ndindex(source_batch.shape):
                    if random.uniform(0, 1) <= 0.0005:
                        source_batch[i][j] = 0

                feed_dict = {t_encoder_lengths_ph: source_length, t_encoder_char_ids_ph: source_batch,
                             t_decoder_lengths: target_length, t_decoder_char_ids_ph: target_batch}

                train_sess.run(t_adam_optimize, feed_dict)

            # Save train model, and restore it into the eval session
            checkpoint_path = train_saver.save(train_sess, checkpoints_path, global_step=epoch)
            eval_saver.restore(eval_sess, checkpoint_path)

        with eval_graph.as_default():
            count_correct = 0
            for batch_index in range(floor(len(source_dev) / batch_size)):
                target_batch, target_length, source_batch, source_length = \
                    mini_batch(target_char_dict, target_dev, source_char_dict, source_dev, batch_size, batch_index)

                begin = batch_index * batch_size
                end = min(begin + batch_size, len(source_dev))
                target_strings = target_dev[begin:end]

                feed_dict = {e_encoder_lengths_ph: source_length, e_encoder_char_ids_ph: source_batch}
                result = eval_sess.run(e_dec_out, feed_dict)

                # Decode predicted char ids back into strings, skipping the
                # ETX end-of-sequence marker (added last to the dict, so its
                # id is len(target_char_dict) - 1)
                decoded_dates = []
                for coded_date in result:
                    date = ""
                    for char_id in coded_date:
                        if char_id != len(target_char_dict) - 1:
                            date += target_dict_rev[char_id]
                    decoded_dates.append(date)

                for expected, predicted in zip(target_strings, decoded_dates):
                    if expected == predicted:
                        count_correct += 1

            # Note: only full batches are evaluated, so any examples in a
            # trailing partial batch count as incorrect here
            print("Dev: " + str(count_correct / len(target_dev)))

    with TemporaryDirectory() as temp_dir:

        temp_model_dir = temp_dir + "/model"
        with eval_graph.as_default():
            builder = tf.saved_model.builder.SavedModelBuilder(temp_model_dir)
            builder.add_meta_graph_and_variables(eval_sess, [tf.saved_model.tag_constants.SERVING])
            builder.save()

        write_mapping(source_char_dict, temp_model_dir + '/source_char_dict.txt')
        write_mapping(target_char_dict, temp_model_dir + '/target_char_dict.txt')

        # Bundle the SavedModel and char mappings into normalizer.zip; the
        # with-statement ensures the archive is closed and fully written
        with zipfile.ZipFile("normalizer.zip", 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(temp_model_dir):
                for file in files:
                    model_file = os.path.join(root, file)
                    zipf.write(model_file, arcname=os.path.relpath(model_file, temp_model_dir))
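
A typical invocation (file names here are illustrative):

    python normalizer.py train.txt dev.txt test.txt

This trains for 20 epochs, prints dev-set accuracy after each epoch, and writes normalizer.zip containing the SavedModel plus source_char_dict.txt and target_char_dict.txt.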