def main()

in containers/training/resources/train.py [0:0]

SageMaker training entry point: reads the hyperparameters and data channel paths that SageMaker injects as SM_* environment variables, fine-tunes DistilBERT for sequence classification on IMDb reviews, and saves the trained model and tokenizer to the model directory.

# Module-level imports required by main(); read_imdb_split() and
# compute_metrics() are helper functions defined elsewhere in train.py.
import json
import os

import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import (
    DistilBertTokenizerFast,
    TFDistilBertForSequenceClassification,
    TFTrainer,
    TFTrainingArguments,
)


def main():
    # SageMaker injects hyperparameters and input channel paths as environment variables.
    hyperparameters = json.loads(os.environ['SM_HPS'])
    model_dir = os.environ['SM_MODEL_DIR']
    log_dir = os.environ['SM_MODEL_DIR']  # training logs are kept alongside the model artifacts
    train_data_dir = os.environ['SM_CHANNEL_TRAIN']
    
    # When tokenizer download is disabled, load the tokenizer from the "tokenizer"
    # input channel instead of downloading it from the Hugging Face Hub.
    if hyperparameters["tokenizer_download_model"] == "disable":
        tokenizer_model = os.environ['SM_CHANNEL_TOKENIZER']
    else:
        tokenizer_model = 'distilbert-base-uncased'
    
    train_texts, train_labels = read_imdb_split(train_data_dir)
    # Hold out 20% of the training data for validation.
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        train_texts, train_labels, test_size=0.2
    )
    tokenizer = DistilBertTokenizerFast.from_pretrained(tokenizer_model)
    
    train_encodings = tokenizer(train_texts, truncation=True, padding=True)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True)

    # Wrap the encodings and labels as tf.data.Dataset objects for TFTrainer.
    train_dataset = tf.data.Dataset.from_tensor_slices((
        dict(train_encodings),
        train_labels
    ))
    val_dataset = tf.data.Dataset.from_tensor_slices((
        dict(val_encodings),
        val_labels
    ))
    
    training_args = TFTrainingArguments(
        output_dir='./results',  # output directory
        num_train_epochs=hyperparameters["num_train_epochs"],  # total number of training epochs
        per_device_train_batch_size=hyperparameters["per_device_train_batch_size"],  # batch size per device during training
        per_device_eval_batch_size=hyperparameters["per_device_eval_batch_size"],  # batch size for evaluation
        warmup_steps=hyperparameters["warmup_steps"],  # number of learning-rate warmup steps
        weight_decay=hyperparameters["weight_decay"],  # strength of weight decay
        logging_dir=log_dir,  # directory for storing logs
        logging_steps=hyperparameters["logging_steps"],
        eval_steps=hyperparameters["eval_steps"],
        evaluation_strategy="steps",  # evaluate every eval_steps
    )

    # Build the model inside the distribution strategy scope selected by
    # TFTrainingArguments. Note that the model weights are always downloaded
    # from the Hub; only the tokenizer can be supplied via an input channel.
    with training_args.strategy.scope():
        model = TFDistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
    
    trainer = TFTrainer(
        model=model,                      # the instantiated 🤗 Transformers model to be trained
        args=training_args,               # training arguments, defined above
        train_dataset=train_dataset,      # training dataset
        eval_dataset=val_dataset,         # evaluation dataset
        compute_metrics=compute_metrics,  # metric callback, defined elsewhere in train.py
    )

    trainer.train()
    # Save the fine-tuned model and tokenizer to SM_MODEL_DIR so SageMaker
    # packages them into the model artifact (model.tar.gz).
    trainer.model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
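
read_imdb_split() is called above but not shown in this extract. A minimal sketch of what it plausibly does, assuming the train channel contains the standard IMDb directory layout with pos/ and neg/ subfolders holding one review per text file (the layout is an assumption, not confirmed by this extract):

from pathlib import Path

def read_imdb_split(split_dir):
    # Hypothetical helper: walk the pos/ and neg/ subfolders and return
    # parallel lists of review texts and binary labels (1 = positive).
    split_dir = Path(split_dir)
    texts, labels = [], []
    for label_dir in ["pos", "neg"]:
        for text_file in (split_dir / label_dir).iterdir():
            texts.append(text_file.read_text())
            labels.append(1 if label_dir == "pos" else 0)
    return texts, labels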
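
compute_metrics is likewise defined elsewhere in train.py. A plausible sketch, assuming the conventional classification metrics computed from the EvalPrediction object that TFTrainer passes in (the exact metric set is an assumption):

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    # Hypothetical metric callback: pred.label_ids holds the gold labels,
    # pred.predictions the raw logits; argmax picks the predicted class.
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }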
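
For context, the SM_HPS, SM_CHANNEL_TRAIN, and SM_CHANNEL_TOKENIZER variables read at the top of main() are populated by the SageMaker training job that runs this script. A minimal launch sketch using the SageMaker Python SDK; the bucket names, IAM role, instance type, framework version, and hyperparameter values below are placeholders, not values from this repository:

from sagemaker.tensorflow import TensorFlow

# The channel names passed to fit() ("train", "tokenizer") surface inside the
# container as SM_CHANNEL_TRAIN and SM_CHANNEL_TOKENIZER; the hyperparameters
# dict is serialized into SM_HPS.
estimator = TensorFlow(
    entry_point='train.py',
    source_dir='containers/training/resources',
    role='arn:aws:iam::123456789012:role/SageMakerRole',  # placeholder role
    instance_count=1,
    instance_type='ml.p3.2xlarge',  # placeholder instance type
    framework_version='2.4.1',      # assumed TensorFlow version
    py_version='py37',
    hyperparameters={
        'tokenizer_download_model': 'disable',
        'num_train_epochs': 1,
        'per_device_train_batch_size': 16,
        'per_device_eval_batch_size': 64,
        'warmup_steps': 500,
        'weight_decay': 0.01,
        'logging_steps': 100,
        'eval_steps': 100,
    },
)

estimator.fit({
    'train': 's3://example-bucket/imdb/train',     # placeholder S3 URIs
    'tokenizer': 's3://example-bucket/tokenizer',
})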