courses/DSL/challenge-mlprep/fraud_detection/trainer/task.py (52 lines of code) (raw):

"""Argument definitions for model training code in `trainer.model`.

Parses command-line flags and hands them, as a plain dict, to
`model.train_and_evaluate`.
"""

import argparse

from trainer import model


def _build_parser():
    """Return the argument parser for the training task.

    Flags are registered in a fixed order; that order is also the order they
    appear in ``--help`` output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch_size",
        help="Batch size for training steps",
        type=int,
        default=32,
    )
    parser.add_argument(
        "--eval_data_path",
        help="GCS location pattern of eval files",
        required=True,
    )
    parser.add_argument(
        "--num_bins",
        help="Number of buckets for float-valued fields",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--hash_bkts",
        help="Number of hash buckets for id fields",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--num_evals",
        help="Number of times to evaluate model on eval data during training.",
        type=int,
        default=5,
    )
    parser.add_argument(
        "--num_examples_to_train_on",
        help="Number of examples to train on.",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--output_dir",
        help="GCS location to write checkpoints and export models",
        required=True,
    )
    parser.add_argument(
        "--train_data_path",
        help="GCS location pattern of train files",
        required=True,
    )
    return parser


def main(argv=None):
    """Parse command-line arguments and run training and evaluation.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which matches running the module
            as a script.
    """
    args = _build_parser().parse_args(argv)
    # vars() returns the Namespace's attribute dict, keyed by each flag's
    # dest name (e.g. "batch_size", "train_data_path").
    hparams = vars(args)
    # Pass input arguments to function to train and evaluate model.
    # NOTE(review): presumably train_and_evaluate looks up these same keys;
    # confirm against trainer/model.py before renaming any flag.
    model.train_and_evaluate(hparams)


if __name__ == "__main__":
    main()