in fastchat/train/train_lora_t5.py [0:0]
def train():
    """Fine-tune a seq2seq model with LoRA (optionally QLoRA) and save the result.

    The function is a script entry point: it takes no arguments and returns
    nothing. All configuration comes from the command line, parsed into the
    ``ModelArguments``, ``DataArguments``, ``TrainingArguments`` and
    ``LoraArguments`` dataclasses.

    Steps, in order:

    1. Parse arguments and work out whether this is a multi-process run
       (``WORLD_SIZE`` environment variable different from 1).
    2. Load the base model, 4-bit quantized when ``lora_args.q_lora`` is set.
    3. Wrap the model with a LoRA adapter.
    4. Load the T5 tokenizer and resize tokenizer/embeddings for extra tokens.
    5. Train, resuming when ``output_dir`` already holds a ``checkpoint-*``.
    6. Collect the state dict and save it from the main process.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        # Under DDP, pin the whole quantized model to this process's local device.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP and ZeRO3 are both currently incompatible with QLoRA."
            )

    # Compute dtype for the 4-bit layers follows the requested mixed-precision mode.
    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if lora_args.q_lora
        else None,
    )

    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type=TaskType.SEQ_2_SEQ_LM,
    )

    if lora_args.q_lora:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
        if not ddp and torch.cuda.device_count() > 1:
            # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
            model.is_parallelizable = True
            model.model_parallel = True

    model = get_peft_model(model, lora_config)
    # The parameter summary is only printed when a DeepSpeed config was given.
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend
    # a space before special tokens.
    tokenizer = transformers.T5Tokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
        other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"],
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    # Resume automatically when a previous run left checkpoints in output_dir.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # ``local_rank`` is documented to default to -1 when no distributed
    # launcher sets it. Treat that case as the main process too; otherwise a
    # plain single-process run would train and then never save the adapter.
    is_main_process = training_args.local_rank in (-1, 0)

    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        # use deepspeed engine internal function to gather state dict
        # state_dict_zero3 contains whole parameters of base and lora adapters
        # we will not extract lora parameters since peft save_pretrained will do that
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19
        # NOTE(review): this is invoked on every rank, as in the original code;
        # presumably the gather is collective -- keep it outside the rank check.
        state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
        # Only the saving process needs the gathered weights; binding None on
        # the other ranks avoids leaving ``state_dict`` undefined there.
        state_dict = state_dict_zero3 if is_main_process else None
    else:
        # in other mode we use original code from fastchat team, to make sure our change is minimum
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), lora_args.lora_bias
        )

    if is_main_process:
        safe_save_model_for_hf_trainer(
            trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict
        )