in trl/trainer/dpo_trainer.py [0:0]
def _create_model_from_path(self, model_path: str, args: DPOConfig, is_ref: bool = False) -> PreTrainedModel:
    """Creates a model from a path or model identifier.

    Args:
        model_path: Path or model identifier forwarded to `AutoModelForCausalLM.from_pretrained`.
        args: Training config. `args.model_init_kwargs` (or `args.ref_model_init_kwargs` when
            `is_ref` is True) supplies extra keyword arguments for `from_pretrained`; either may be None.
        is_ref: If True, read the reference-model init kwargs instead of the policy-model ones.

    Returns:
        The model returned by `AutoModelForCausalLM.from_pretrained`.

    Raises:
        ValueError: If `torch_dtype` in the init kwargs is not None, "auto", a `torch.dtype`, or a
            string naming a `torch.dtype` attribute of the `torch` module (e.g. "float32").
    """
    # Work on a shallow copy so that normalizing `torch_dtype` below does not rewrite the
    # user's config dict in place (which would leave a non-serializable torch.dtype in it).
    if not is_ref:
        model_init_kwargs = dict(args.model_init_kwargs or {})
    else:
        model_init_kwargs = dict(args.ref_model_init_kwargs or {})

    # Handle torch dtype
    torch_dtype = model_init_kwargs.get("torch_dtype")
    if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
        pass  # torch_dtype is already a torch.dtype or "auto" or None; forwarded unchanged
    elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
        resolved_dtype = getattr(torch, torch_dtype, None)
        # A plain `getattr(torch, name)` raises AttributeError for unknown names and accepts any
        # non-dtype attribute (e.g. "nn"); only accept real dtypes and report errors as ValueError.
        if not isinstance(resolved_dtype, torch.dtype):
            raise ValueError(
                "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )
        model_init_kwargs["torch_dtype"] = resolved_dtype
    else:
        raise ValueError(
            "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing "
            f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
        )

    # NOTE(review): `use_cache` is not forced to False here even when `args.gradient_checkpointing`
    # is enabled (code doing so was commented out in this method); confirm whether caching should be
    # disabled for gradient checkpointing.

    # Create model
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
    return model