def _create_model_from_path()

in trl/trainer/dpo_trainer.py [0:0]


    def _create_model_from_path(self, model_path: str, args: DPOConfig, is_ref: bool = False) -> PreTrainedModel:
        """Creates a model from a path or model identifier.

        Args:
            model_path: Local path or model identifier forwarded to ``AutoModelForCausalLM.from_pretrained``.
            args: Trainer configuration. ``args.model_init_kwargs`` (policy model) or
                ``args.ref_model_init_kwargs`` (reference model) supplies extra keyword arguments for
                ``from_pretrained``; ``None`` is treated as no extra arguments.
            is_ref: If True, read ``args.ref_model_init_kwargs`` instead of ``args.model_init_kwargs``.

        Returns:
            The model returned by ``AutoModelForCausalLM.from_pretrained``.

        Raises:
            ValueError: If ``torch_dtype`` is not a ``torch.dtype``, ``"auto"``, ``None``, or the name of a
                ``torch.dtype`` available on the ``torch`` module (e.g. ``"float32"``).
        """
        source_kwargs = args.ref_model_init_kwargs if is_ref else args.model_init_kwargs
        # Shallow-copy so that normalizing `torch_dtype` below does not write back into the dict held by
        # `args`; otherwise the caller's config would have its string replaced by a `torch.dtype` object.
        model_init_kwargs = dict(source_kwargs or {})

        # Normalize `torch_dtype`: a string name such as "float32" is converted to the matching torch.dtype,
        # while an existing torch.dtype, "auto", or None is forwarded unchanged.
        torch_dtype = model_init_kwargs.get("torch_dtype")
        invalid_dtype_message = (
            "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing "
            f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
        )
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass  # torch_dtype is already a torch.dtype or "auto" or None
        elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
            # Use a default so unknown names raise the documented ValueError rather than AttributeError,
            # and reject torch attributes that exist but are not dtypes (e.g. "Tensor").
            resolved_dtype = getattr(torch, torch_dtype, None)
            if not isinstance(resolved_dtype, torch.dtype):
                raise ValueError(invalid_dtype_message)
            model_init_kwargs["torch_dtype"] = resolved_dtype
        else:
            raise ValueError(invalid_dtype_message)

        # Disable caching if gradient checkpointing is enabled (not supported)
        # if args.gradient_checkpointing:
        #     model_init_kwargs["use_cache"] = False

        # Create model
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        return model