in method_comparison/MetaMathQA/run.py [0:0]
def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
    tic_total = time.perf_counter()
    start_date = dt.datetime.now(tz=dt.timezone.utc).replace(microsecond=0).isoformat()
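
    # categorize the run as a MAIN or a TEST run based on the currently checked-out PEFT branch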
    peft_branch = get_peft_branch()
    if peft_branch == "main":
        print_verbose("===== This experiment is categorized as a MAIN run because the PEFT branch is 'main' ======")
    else:
        print_verbose(
            f"===== This experiment is categorized as a TEST run because the PEFT branch is '{peft_branch}' ======"
        )
    # load configs
    peft_config: Optional[PeftConfig] = None
    if os.path.exists(os.path.join(path_experiment, CONFIG_NAME)):
        peft_config = PeftConfig.from_pretrained(path_experiment)
    else:
        print_verbose(f"Could not find PEFT config at {path_experiment}, performing FULL FINETUNING")
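
    # the training hyper-parameters live in a separate params file inside the experiment directory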
    path_train_config = os.path.join(path_experiment, FILE_NAME_TRAIN_PARAMS)
    train_config = get_train_config(path_train_config)
    set_seed(train_config.seed)

    # initialize objects
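    # init_cuda() presumably records the initial CUDA memory usage, used as a baseline for the memory numbers reported later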
    cuda_memory_init = init_cuda()
    tokenizer = get_tokenizer(model_id=train_config.model_id, max_seq_length=train_config.max_seq_length)

    model_info = get_base_model_info(train_config.model_id)
    metamath_info = get_dataset_info("meta-math/MetaMathQA")
    gsm8k_info = get_dataset_info("openai/gsm8k")
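
    # load the base model; if a PEFT config was found above it is applied, otherwise the model is fully finetuned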
    model = get_model(
        model_id=train_config.model_id,
        dtype=train_config.dtype,
        compile=train_config.compile,
        attn_implementation=train_config.attn_implementation,
        peft_config=peft_config,
        autocast_adapter_dtype=train_config.autocast_adapter_dtype,
    )
    print_verbose(model)

    # train model
    train_result = train(
        model=model,
        max_steps=train_config.max_steps,
        batch_size=train_config.batch_size,
        batch_size_eval=train_config.batch_size_eval,
        tokenizer=tokenizer,
        cuda_memory_init=cuda_memory_init,
        eval_steps=train_config.eval_steps,
        generation_kwargs=train_config.generation_kwargs,
        grad_norm_clip=train_config.grad_norm_clip,
        optimizer_type=train_config.optimizer_type,
        optimizer_kwargs=train_config.optimizer_kwargs,
        query_template=train_config.query_template,
        lr_scheduler_arg=train_config.lr_scheduler,
        use_amp=train_config.use_amp,
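        # AdaLoRA is flagged explicitly, presumably so that the train loop can run its extra rank re-allocation step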
        is_adalora=isinstance(peft_config, AdaLoraConfig),
    )
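
    # a failed training run is not logged; exit with a non-zero status instead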
    if train_result.status == TrainStatus.FAILED:
        print_verbose("Training failed, not logging results")
        sys.exit(1)
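
    # measure the size of the trained model / PEFT adapter on disk; `clean` presumably controls whether the
    # saved checkpoint is removed again afterwards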
    file_size = get_file_size(
        model,
        peft_config=peft_config,
        clean=clean,
        print_fn=print_verbose,
    )
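
    # total wall-clock time, measured from the very start of main()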
    time_total = time.perf_counter() - tic_total

    # log results: print and save to file
    log_results(
        experiment_name=experiment_name,
        train_result=train_result,
        cuda_memory_init=cuda_memory_init,
        time_total=time_total,
        file_size=file_size,
        model_info=model_info,
        datasets_info={"metamath": metamath_info, "gsm8k": gsm8k_info},
        start_date=start_date,
        train_config=train_config,
        peft_config=peft_config,
        print_fn=print_verbose,
    )
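

# --- illustrative only: not part of the file excerpt above ---
# A hedged sketch of how main() might be invoked from the command line. The real run.py defines its
# own argument parsing; the flag names below (positional path_experiment, --experiment-name, --clean)
# and the basename fallback for the experiment name are assumptions made for illustration.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run a single MetaMathQA method-comparison experiment")
    parser.add_argument("path_experiment", help="directory containing the training params and optional PEFT config")
    parser.add_argument("--experiment-name", default=None, help="name used when logging the results")
    parser.add_argument("--clean", action="store_true", help="remove the saved checkpoint after its size is measured")
    args = parser.parse_args()

    main(
        path_experiment=args.path_experiment,
        experiment_name=args.experiment_name or os.path.basename(args.path_experiment),
        clean=args.clean,
    )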