in notebooks/src/code/train.py [0:0]
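# Imports needed by this excerpt (the local `data` / `get_model` import paths are
# assumptions inferred from usage below; adjust to the real module layout):
import logging
import os
import shutil

from transformers import EarlyStoppingCallback, PreTrainedTokenizerFast, Trainer
from transformers.trainer_utils import get_last_checkpoint

# import data                   # assumed: local module providing get_datasets()
# from model import get_model   # assumed: local helper returning (config, model, tokenizer)

logger = logging.getLogger(__name__)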
def train(model_args, data_args, training_args):
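    """Fine-tune and save a Hugging Face model per the given argument sets.

    Loads the model/tokenizer and datasets, resumes from a previous checkpoint when one is
    found in `output_dir`, runs training (unless `do_train` is False), saves the final model
    plus a copy of this code folder to `model_dir` for deployment, and returns the Trainer.
    """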
logger.info("Creating config and model")
_, model, tokenizer = get_model(model_args, data_args)
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
"This example script only works for models that have a fast tokenizer. See the list "
"at https://huggingface.co/transformers/index.html#supported-frameworks for details."
)
logger.info("Loading datasets")
    datasets = data.get_datasets(data_args, tokenizer)
    if datasets.train_dataset:
        logger.info(f"train dataset has {len(datasets.train_dataset)} samples")
    else:
        logger.info("No training dataset provided")
    if datasets.eval_dataset:
        logger.info(f"validation dataset has {len(datasets.eval_dataset)} samples")
    else:
        logger.info("No validation dataset provided")
    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            logger.warning("Output directory is not empty but contains no checkpoint: training from scratch")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this "
                "behavior, create the training job with an empty or absent `checkpoint_s3_uri`."
            )
    logger.info("Setting up trainer")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=datasets.train_dataset,
        eval_dataset=datasets.eval_dataset if data_args.validation else None,
        tokenizer=tokenizer,
        data_collator=datasets.data_collator,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=training_args.early_stopping_patience,
                early_stopping_threshold=training_args.early_stopping_threshold,
            )
        ]
        if (
            training_args.early_stopping_patience is not None
            or training_args.early_stopping_threshold is not None
        )
        else [],
        compute_metrics=datasets.metric_computer,
    )
    if not training_args.do_train:
        logger.warning(f"Training skipped (args.do_train={training_args.do_train})")
    else:
        checkpoint = None
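        # An explicit resume_from_checkpoint argument takes priority over any checkpoint
        # auto-detected in output_dir above: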
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()  # (Saves the tokenizer too)
        max_train_samples = (
            data_args.max_train_samples
            if data_args.max_train_samples is not None
            else len(datasets.train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(datasets.train_dataset))
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
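    # On SageMaker, model_dir typically maps to SM_MODEL_DIR (/opt/ml/model), whose contents
    # are packaged into model.tar.gz when the training job completes (assumption based on the
    # SageMaker-oriented deployment notes below):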
logger.info(f"Saving model to {training_args.model_dir}")
trainer.save_model(training_args.model_dir) # (Saves the tokenizer too)
    # To enable deploying this model directly via the SageMaker SDK's Estimator.deploy() (rather
    # than needing to create a PyTorchModel with entry_point / source_dir args), any inference
    # handler code must be saved to model_dir/code. Here we trade a little efficiency for usage
    # simplicity by simply copying the contents of this training code folder to the model's
    # code/ folder for inference:
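    # Illustrative only (estimator name and instance type are placeholders): with the code
    # copied below, deployment can then look like
    #   predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")
    # instead of building a separate PyTorchModel with entry_point / source_dir.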
    code_path = os.path.join(training_args.model_dir, "code")
    logger.info(f"Copying code to {code_path} for inference")
    for currpath, _, files in os.walk("."):
        for file in files:
            # Skip any filenames starting with dot:
            if file.startswith("."):
                continue
            filepath = os.path.join(currpath, file)
            # Skip any pycache or dot folders:
            if ((os.path.sep + ".") in filepath) or ("__pycache__" in filepath):
                continue
            relpath = filepath[len(".") :]
            if relpath.startswith(os.path.sep):
                relpath = relpath[1:]
            outpath = os.path.join(code_path, relpath)
            logger.info(f"Copying {filepath} to {outpath}")
            os.makedirs(outpath.rpartition(os.path.sep)[0], exist_ok=True)
            shutil.copy2(filepath, outpath)
    return trainer
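
# Illustrative usage sketch (not part of the original file): how train() might be invoked from
# a __main__ entrypoint, assuming a hypothetical parse_args() that returns the three argument
# objects (e.g. via transformers.HfArgumentParser):
#
# if __name__ == "__main__":
#     model_args, data_args, training_args = parse_args()
#     train(model_args, data_args, training_args)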