in src/autotrain/trainers/extractive_question_answering/__main__.py [0:0]
def _load_qa_split(config, split_spec):
    """Load one dataset split for extractive QA training or validation.

    If ``config.data_path`` equals ``f"{config.project_name}/autotrain-data"``,
    the split is read with ``load_from_disk``. Otherwise it is fetched with
    ``load_dataset``. A ``split_spec`` containing ``":"`` is interpreted as
    ``"<dataset_config_name>:<split>"``.

    Args:
        config: the training params object (reads ``data_path``,
            ``project_name`` and ``token``).
        split_spec: split name, optionally prefixed by ``"<config_name>:"``.

    Returns:
        The loaded split.
    """
    if config.data_path == f"{config.project_name}/autotrain-data":
        logger.info("loading dataset from disk")
        return load_from_disk(config.data_path)[split_spec]
    # NOTE(review): because any ":" is treated as a config-name separator,
    # HF slice syntax such as "train[:10%]" would be split into name="train["
    # and split="10%]", and specs with two colons raise ValueError on unpacking.
    if ":" in split_spec:
        dataset_config_name, split = split_spec.split(":")
        return load_dataset(
            config.data_path,
            name=dataset_config_name,
            split=split,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
        )
    return load_dataset(
        config.data_path,
        split=split_spec,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
    )


def train(config):
    """Fine-tune a model for extractive question answering.

    Steps:
    1. Load the train and (optional) validation splits.
    2. Load the model config, model and tokenizer.
    3. Wrap the data and run ``Trainer``.
    4. Save the model, tokenizer and a ``README.md`` model card to
       ``config.project_name``.
    5. Optionally push that folder to the Hugging Face Hub as a private
       model repo.

    Args:
        config: an ``ExtractiveQuestionAnsweringParams`` instance, or a dict
            of its fields (converted on entry).
    """
    if isinstance(config, dict):
        config = ExtractiveQuestionAnsweringParams(**config)

    has_valid = config.valid_split is not None

    train_data = None
    valid_data = None
    if config.train_split is not None:
        train_data = _load_qa_split(config, config.train_split)
    if has_valid:
        valid_data = _load_qa_split(config, config.valid_split)

    logger.info(train_data)
    if has_valid:
        logger.info(valid_data)

    # AutoConfig honours ``trust_remote_code``; the previous
    # ``allow_remote_code`` keyword did not enable remote code here.
    model_config = AutoConfig.from_pretrained(
        config.model,
        trust_remote_code=ALLOW_REMOTE_CODE,
        token=config.token,
    )
    try:
        model = AutoModelForQuestionAnswering.from_pretrained(
            config.model,
            config=model_config,
            trust_remote_code=ALLOW_REMOTE_CODE,
            token=config.token,
            ignore_mismatched_sizes=True,
        )
    except OSError:
        # Fall back to loading TensorFlow weights when PyTorch weights fail to load.
        model = AutoModelForQuestionAnswering.from_pretrained(
            config.model,
            config=model_config,
            from_tf=True,
            trust_remote_code=ALLOW_REMOTE_CODE,
            token=config.token,
            ignore_mismatched_sizes=True,
        )
    tokenizer = AutoTokenizer.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
    )

    use_v2 = False
    if has_valid:
        # Any answer_start of -1 switches on use_v2 in compute_metrics
        # (presumably SQuAD v2-style scoring of unanswerable questions).
        use_v2 = any(-1 in row[config.answer_column]["answer_start"] for row in valid_data)
        # Give each raw validation example a stable integer id.
        valid_data = valid_data.add_column("id", list(range(len(valid_data))))
        processed_eval_dataset = valid_data.map(
            partial(
                utils.prepare_qa_validation_features,
                tokenizer=tokenizer,
                config=config,
            ),
            batched=True,
            remove_columns=valid_data.column_names,
            num_proc=2,
            desc="Running tokenizer on validation dataset",
        )
        # Raw examples (with ids) are passed as eval_examples to compute_metrics.
        orig_valid_data = copy.deepcopy(valid_data)

    train_data = ExtractiveQuestionAnsweringDataset(data=train_data, tokenizer=tokenizer, config=config)
    if has_valid:
        valid_data = ExtractiveQuestionAnsweringDataset(data=valid_data, tokenizer=tokenizer, config=config)

    if config.logging_steps == -1:
        # Auto mode: a fifth of the batches in the eval (else train) set,
        # clamped to the range [1, 25].
        reference_data = valid_data if has_valid else train_data
        logging_steps = int(0.2 * len(reference_data) / config.batch_size)
        logging_steps = min(max(logging_steps, 1), 25)
        config.logging_steps = logging_steps
    else:
        logging_steps = config.logging_steps
    logger.info(f"Logging steps: {logging_steps}")

    training_args = dict(
        output_dir=config.project_name,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=2 * config.batch_size,
        learning_rate=config.lr,
        num_train_epochs=config.epochs,
        eval_strategy=config.eval_strategy if has_valid else "no",
        logging_steps=logging_steps,
        save_total_limit=config.save_total_limit,
        save_strategy=config.eval_strategy if has_valid else "no",
        gradient_accumulation_steps=config.gradient_accumulation,
        report_to=config.log,
        auto_find_batch_size=config.auto_find_batch_size,
        lr_scheduler_type=config.scheduler,
        optim=config.optimizer,
        warmup_ratio=config.warmup_ratio,
        weight_decay=config.weight_decay,
        max_grad_norm=config.max_grad_norm,
        push_to_hub=False,
        load_best_model_at_end=has_valid,
        ddp_find_unused_parameters=False,
    )
    if config.mixed_precision == "fp16":
        training_args["fp16"] = True
    if config.mixed_precision == "bf16":
        training_args["bf16"] = True

    callbacks_to_use = []
    if has_valid:
        callbacks_to_use.append(
            EarlyStoppingCallback(
                early_stopping_patience=config.early_stopping_patience,
                early_stopping_threshold=config.early_stopping_threshold,
            )
        )
    callbacks_to_use.extend([UploadLogs(config=config), LossLoggingCallback(), TrainStartCallback()])

    compute_metrics = None
    if has_valid:
        logger.info(processed_eval_dataset)
        compute_metrics = partial(
            utils.compute_metrics,
            eval_dataset=processed_eval_dataset,
            eval_examples=orig_valid_data,
            config=config,
            use_v2=use_v2,
        )

    trainer = Trainer(
        args=TrainingArguments(**training_args),
        model=model,
        callbacks=callbacks_to_use,
        compute_metrics=compute_metrics,
        train_dataset=train_data,
        eval_dataset=valid_data,
    )
    trainer.remove_callback(PrinterCallback)
    trainer.train()

    logger.info("Finished training, saving model...")
    trainer.save_model(config.project_name)
    tokenizer.save_pretrained(config.project_name)

    model_card = utils.create_model_card(config, trainer)
    # Save the model card to the output directory as README.md; explicit UTF-8
    # so non-ASCII card content does not depend on the platform's default encoding.
    with open(f"{config.project_name}/README.md", "w", encoding="utf-8") as f:
        f.write(model_card)

    if config.push_to_hub:
        # Only the main process cleans up and uploads.
        if PartialState().process_index == 0:
            remove_autotrain_data(config)
            save_training_params(config)
            logger.info("Pushing model to hub...")
            api = HfApi(token=config.token)
            repo_id = f"{config.username}/{config.project_name}"
            api.create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True)
            api.upload_folder(
                folder_path=config.project_name,
                repo_id=repo_id,
                repo_type="model",
            )

    if PartialState().process_index == 0:
        pause_space(config)