in pytorch_alternatives/custom_pytorch_nlp/src/main.py [0:0]
def parse_args(argv=None):
    """Acquire hyperparameters and directory locations passed by SageMaker.

    Hyperparameters arrive as command-line arguments; the data, model, and
    output locations default to the corresponding ``SM_*`` environment
    variables and can be overridden on the command line. When one of those
    environment variables is not set, the matching attribute defaults to
    ``None``.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``sys.argv[1:]`` is parsed.

    Returns:
        The ``(namespace, unknown)`` tuple produced by
        ``ArgumentParser.parse_known_args``: ``namespace`` holds the parsed
        values and ``unknown`` is the list of argument strings that the
        parser did not recognise. Callers must unpack the tuple.
    """
    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--num_classes", type=int, default=4)
    parser.add_argument("--max_seq_len", type=int, default=40)

    # Data, model, and output directories
    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR"))
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--test", type=str, default=os.environ.get("SM_CHANNEL_TEST"))
    parser.add_argument("--embeddings", type=str, default=os.environ.get("SM_CHANNEL_EMBEDDINGS"))

    # parse_known_args (rather than parse_args) tolerates extra arguments
    # instead of exiting; they are returned as the tuple's second element.
    return parser.parse_known_args(argv)