in fastchat/train/train_lora.py
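
# The imports below are a sketch of what this excerpt relies on; the
# project-specific helpers (ModelArguments, DataArguments, TrainingArguments,
# LoraArguments, replace_llama_attn_with_flash_attn, make_supervised_data_module,
# get_peft_state_maybe_zero_3) are defined elsewhere in fastchat/train and are
# not reproduced here.
import logging
import os
import pathlib

import torch
import transformers
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# `deepspeed` here is transformers' DeepSpeed integration module, which
# provides is_deepspeed_zero3_enabled().
from transformers import BitsAndBytesConfig, Trainer, deepspeed
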
def train():
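    # Parse command-line flags into the four dataclass groups
    # (model, data, training, and LoRA hyperparameters).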
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()
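
    # Optionally monkey-patch LLaMA attention with FlashAttention before the
    # model is instantiated.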
    if training_args.flash_attn:
        replace_llama_attn_with_flash_attn()
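
    # Under QLoRA with DDP, pin the quantized model to this process's GPU via
    # device_map; warn if FSDP or DeepSpeed ZeRO-3 is active, since both are
    # currently incompatible with QLoRA.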
    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP and ZeRO3 are both currently incompatible with QLoRA."
            )
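
    # Select the compute dtype matching the requested mixed-precision mode.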
    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )
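
    # Load the base model; under QLoRA the weights are quantized to 4-bit NF4
    # (with double quantization) via bitsandbytes, otherwise no quantization
    # config is passed.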
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if lora_args.q_lora
        else None,
    )
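    # LoRA adapter configuration built from the parsed LoraArguments.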
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type="CAUSAL_LM",
    )
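
    # Under QLoRA, prepare the 4-bit model for training (peft freezes the base
    # weights and optionally enables gradient checkpointing).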
    if lora_args.q_lora:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
        if not ddp and torch.cuda.device_count() > 1:
            # Keep the Trainer from attempting its own DataParallelism when
            # more than one GPU is available.
            model.is_parallelizable = True
            model.model_parallel = True
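
    # Wrap the base model with the LoRA adapters defined above.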
    model = get_peft_model(model, lora_config)
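    # FlashAttention expects fp16/bf16 activations, so cast the norm layers and
    # the input/output embeddings to the compute dtype (nn.Module.to() casts
    # in place).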
    if training_args.flash_attn:
        for name, module in model.named_modules():
            if "norm" in name:
                module = module.to(compute_dtype)
            if "lm_head" in name or "embed_tokens" in name:
                if hasattr(module, "weight"):
                    module = module.to(compute_dtype)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()
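
    # Right-padded slow tokenizer; the LLaMA tokenizer ships without a pad
    # token, so the unk token is reused below.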
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
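
    # Build the supervised dataset and collator, then hand everything to the
    # Hugging Face Trainer.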
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )
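
    # The KV cache is unnecessary during training and conflicts with gradient
    # checkpointing.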
    model.config.use_cache = False
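
    # Resume from the most recent checkpoint if the output directory already
    # contains one.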
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Check whether DeepSpeed ZeRO-3 is enabled.
    if deepspeed.is_deepspeed_zero3_enabled():
        # Use the DeepSpeed engine's internal helper to gather the state dict.
        # state_dict_zero3 contains the full parameters of both the base model
        # and the LoRA adapters; the LoRA parameters are not extracted here
        # because peft's save_pretrained does that:
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19
        state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
        if training_args.local_rank == 0:
            state_dict = state_dict_zero3
    else:
        # In the non-ZeRO-3 case, keep the original FastChat code so the
        # change stays minimal.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), lora_args.lora_bias
        )

    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)