in src/autotrain/trainers/vlm/train_vlm_generic.py [0:0]
def train(config):
    """Fine-tune a vision-language model as described by ``config``.

    Loads the training split (and the validation split when
    ``config.valid_split`` is set), builds the processor, training arguments,
    model and callbacks, runs ``Trainer.train()`` and finally hands the
    trainer to ``utils.post_training_steps``.

    Args:
        config: Training parameters object. Attributes read directly here are
            ``data_path``, ``project_name``, ``train_split``, ``valid_split``,
            ``token``, ``trainer`` and ``model``. ``prompt_text_column`` is
            overwritten with ``"caption"`` when ``trainer == "captioning"``.
            The object is also forwarded to the ``utils`` helpers and to
            ``collate_fn``.
    """
    train_data = _load_vlm_split(config, config.train_split)
    valid_data = None
    if config.valid_split is not None:
        valid_data = _load_vlm_split(config, config.valid_split)

    logger.info(f"Train data: {train_data}")
    logger.info(f"Valid data: {valid_data}")

    # The captioning trainer always reads its target text from "caption".
    if config.trainer == "captioning":
        config.prompt_text_column = "caption"

    processor = AutoProcessor.from_pretrained(config.model, token=config.token, trust_remote_code=ALLOW_REMOTE_CODE)

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    args = TrainingArguments(**training_args)

    model = utils.get_model(config)

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)

    # Bind config and processor once so the Trainer can call the collator
    # with the batch as its only argument.
    col_fn = partial(collate_fn, config=config, processor=processor)

    trainer = Trainer(
        args=args,
        model=model,
        callbacks=callbacks,
        train_dataset=train_data,
        eval_dataset=valid_data,  # None when no validation split is configured
        data_collator=col_fn,
    )
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)


def _load_vlm_split(config, split_spec):
    """Load a single dataset split for ``train``.

    Args:
        config: Training parameters object; ``data_path``, ``project_name``
            and ``token`` are read.
        split_spec: Either a plain split name, or ``"<config_name>:<split>"``
            to select a named dataset configuration.

    Returns:
        The dataset object for the requested split.
    """
    # A data_path of "<project_name>/autotrain-data" is read from disk
    # instead of being fetched with load_dataset.
    if config.data_path == f"{config.project_name}/autotrain-data":
        return load_from_disk(config.data_path)[split_spec]

    load_kwargs = {"split": split_spec, "token": config.token}
    if ":" in split_spec:
        # Split on the first colon only, so a split expression that itself
        # contains a colon (e.g. "cfg:train[:10%]") no longer fails with an
        # unpacking ValueError.
        dataset_config_name, split = split_spec.split(":", 1)
        load_kwargs.update(name=dataset_config_name, split=split)
    return load_dataset(config.data_path, **load_kwargs)