def main()

in src/main.py


import argparse
import logging
import os
import pickle
import sys

# NOTE: Builder is assumed importable from the surrounding package; adjust the
# import path to match the actual project layout.
from builder import Builder


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--trainfile",
                        help="The input train file wrt to train  dir", required=True)
    parser.add_argument("--traindir",
                        help="The input train  dir", default=os.environ.get("SM_CHANNEL_TRAIN", "."))

    parser.add_argument("--valfile",
                        help="The input val file wrt to val  dir", required=True)
    parser.add_argument("--valdir",
                        help="The input val dir", default=os.environ.get("SM_CHANNEL_VAL", "."))

    parser.add_argument("--classfile",
                        help="The classes txt file which is a list of classes for dbpedia", required=True)
    parser.add_argument("--classdir",
                        help="The class file  dir", default=os.environ.get("SM_CHANNEL_CLASS", "."))

    parser.add_argument("--outdir", help="The output dir", default=os.environ.get("SM_OUTPUT_DATA_DIR", "."))
    parser.add_argument("--modeldir", help="The model dir", default=os.environ.get("SM_MODEL_DIR", "."))
    parser.add_argument("--checkpointdir", help="The checkpoint dir", default=None)
    parser.add_argument("--checkpointfreq",
                        help="The checkpoint frequency, only applies if the checkpoint dir is set", default=1)

    parser.add_argument("--earlystoppingpatience", help="The number of patience epochs", type=int,
                        default=10)
    parser.add_argument("--epochs", help="The number of epochs", type=int, default=10)
    parser.add_argument("--gradaccumulation", help="The number of gradient accumulation steps", type=int, default=1)
    parser.add_argument("--batch", help="The batchsize", type=int, default=32)

    parser.add_argument("--lr", help="The learning rate", type=float, default=0.0001)
    parser.add_argument("--finetune", help="Fine tunes the final layer (classifier) model instead of the entire model",
                        type=int, default=0, choices={1, 0})
    parser.add_argument("--maxseqlen",
                        help="The max sequence len, any input that is greater than this will be truncated and fed into the network. If too large, the the bert model will not support it or you will end up Cuda OOM error! ",
                        type=int, default=256)

    parser.add_argument("--log-level", help="Log level", default="INFO", choices={"INFO", "WARN", "DEBUG", "ERROR"})
    args = parser.parse_args()
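    # Example local invocation (hypothetical file names, for illustration only):
    #   python src/main.py --trainfile train.csv --valfile val.csv \
    #       --classfile classes.txt --epochs 5 --batch 16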

    # Set up logging
    logging.basicConfig(level=logging.getLevelName(args.log_level), handlers=[logging.StreamHandler(sys.stdout)],
                        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    print(args.__dict__)

    train_data_file = os.path.join(args.traindir, args.trainfile)
    val_data_file = os.path.join(args.valdir, args.valfile)
    classes_file = os.path.join(args.classdir, args.classfile)

    b = Builder(train_data=train_data_file, val_data=val_data_file, labels_file=classes_file,
                checkpoint_dir=args.checkpointdir, epochs=args.epochs,
                early_stopping_patience=args.earlystoppingpatience, batch_size=args.batch, max_seq_len=args.maxseqlen,
                learning_rate=args.lr, fine_tune=args.finetune, model_dir=args.modeldir)

    trainer = b.get_trainer()

    # Persist the label mapper so it can be used at inference time
    label_mapper_pickle_file = os.path.join(args.modeldir, "label_mapper.pkl")
    with open(label_mapper_pickle_file, "wb") as f:
        pickle.dump(b.get_label_mapper(), f)

    # Persist the preprocessor (tokeniser)
    preprocessor_pickle_file = os.path.join(args.modeldir, "preprocessor.pkl")
    with open(preprocessor_pickle_file, "wb") as f:
        pickle.dump(b.get_preprocessor(), f)
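
    # At inference time the serving code (assumed, not part of this file) would
    # reload these pickled artefacts with the matching pickle calls, e.g.:
    #   with open(os.path.join(model_dir, "preprocessor.pkl"), "rb") as f:
    #       preprocessor = pickle.load(f)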

    # Run training
    train_dataloader, val_dataloader = b.get_train_val_dataloader()
    trainer.run_train(train_iter=train_dataloader,
                      validation_iter=val_dataloader,
                      model_network=b.get_network(),
                      loss_function=b.get_loss_function(),
                      optimizer=b.get_optimiser(), pos_label=b.get_pos_label_index())
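

# Entry-point guard, assumed here so the module can be run as a script
# (the original file may define this outside the excerpt shown above).
if __name__ == "__main__":
    main()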