in src/sal/models/skywork_o1_prm/modeling_base.py [0:0]
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
    Instantiates a new model from a pretrained model from `transformers`. The
    pretrained model is loaded using the `from_pretrained` method of the
    `transformers.PreTrainedModel` class. The arguments that are specific to the
    `transformers.PreTrainedModel` class are passed along to this method and
    filtered out of the `kwargs` argument.

    Args:
        pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`):
            The path to the pretrained model or its name.
        *model_args (`list`, *optional*):
            Additional positional arguments passed along to the underlying model's
            `from_pretrained` method.
        **kwargs (`dict`, *optional*):
            Additional keyword arguments passed along to the underlying model's
            `from_pretrained` method. The kwargs are pre-processed to separate
            the arguments that are specific to the `transformers.PreTrainedModel`
            class from the arguments that are specific to trl models. The kwargs
            also support `prepare_model_for_kbit_training` arguments from the
            `peft` library.
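
    Example (illustrative sketch; `WrapperSubclass` stands in for a concrete
    subclass of this wrapper and the path is a placeholder):

        model = WrapperSubclass.from_pretrained(
            "path/to/pretrained-model",
            is_trainable=False,  # wrapper-specific kwarg, popped before loading
            torch_dtype="auto",  # forwarded to the underlying `from_pretrained`
        )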
"""
if kwargs is not None:
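        # Pop wrapper-specific kwargs first; the remainder is split between the
        # `transformers` `from_pretrained` call and peft quantization handling.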
peft_config = kwargs.pop("peft_config", None)
reward_adapter = kwargs.pop("reward_adapter", None)
reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter")
is_trainable = kwargs.pop("is_trainable", False)
trl_model_args, pretrained_kwargs, peft_quantization_kwargs = (
cls._split_kwargs(kwargs)
)
token = pretrained_kwargs.get("token", None)
    else:
        # Defaults mirror the kwargs popped above so later references
        # (e.g. `reward_adapter`) are always defined.
        peft_config = None
        reward_adapter = None
        reward_adapter_name = "reward_adapter"
        is_trainable = False
        trl_model_args = {}
        pretrained_kwargs = {}
        peft_quantization_kwargs = {}
        token = None
if reward_adapter is not None and not isinstance(reward_adapter, str):
raise ValueError(
"The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter."
)
is_peft_model = False
current_device = cls._get_current_device()
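    # Determine whether the base model is (or will be) loaded in 8-bit/4-bit:
    # from the kwargs for a path/Hub id, or from the attributes of an
    # already-instantiated model.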
if isinstance(pretrained_model_name_or_path, str):
        is_loaded_in_8bit = pretrained_kwargs.get("load_in_8bit", False)
        is_loaded_in_4bit = pretrained_kwargs.get("load_in_4bit", False)
else:
is_loaded_in_8bit = getattr(
pretrained_model_name_or_path, "is_loaded_in_8bit", False
)
is_loaded_in_4bit = getattr(
pretrained_model_name_or_path, "is_loaded_in_4bit", False
)
if (
is_loaded_in_8bit or is_loaded_in_4bit
) and "device_map" not in pretrained_kwargs:
# warn users
        logging.warning(
            "The `device_map` argument is not provided. We will override the `device_map` argument"
            " to set the entire model on the current device. If you want to set the model on"
            " multiple devices, please provide a custom `device_map` argument."
        )
pretrained_kwargs["device_map"] = {"": current_device}
    # First, load the pretrained model using the parent class
    # (either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM`).
if isinstance(pretrained_model_name_or_path, str):
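        # A local `adapter_config.json` would indicate a PEFT adapter checkout;
        # note that neither flag below is consumed further in this method.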
remote_adapter_config = None
local_adapter_present = os.path.exists(
os.path.join(pretrained_model_name_or_path, "adapter_config.json")
)
pretrained_model = cls.transformers_parent_class.from_pretrained(
pretrained_model_name_or_path, *model_args, **pretrained_kwargs
)
elif isinstance(
pretrained_model_name_or_path, cls.supported_pretrained_model_architectures
):
pretrained_model = pretrained_model_name_or_path
else:
raise ValueError(
"pretrained_model_name_or_path should be a string or a PreTrainedModel, "
f"but is {type(pretrained_model_name_or_path)}"
)
# Add reward modeling adapter if specified
if not is_peft_model and reward_adapter is not None:
        raise ValueError("reward_adapter can only be used with a PeftModel.")
elif is_peft_model and reward_adapter is not None:
score_module = cls.add_and_load_reward_modeling_adapter(
pretrained_model, reward_adapter, reward_adapter_name, token=token
)
multi_adapter_args = {
"score_module": score_module,
"supports_rm_adapter": True,
"rm_adapter_name": reward_adapter_name,
}
else:
multi_adapter_args = {"supports_rm_adapter": False}
# Then, create the full model by instantiating the wrapper class
model = cls(pretrained_model, **multi_adapter_args, **trl_model_args)
    # If resuming training, load the state_dict again - this is OK since the
    # state_dict is removed from the model after loading it.
is_resuming_training = True
if isinstance(pretrained_model_name_or_path, str):
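        # Candidate checkpoint files: a local `model.safetensors` or
        # `pytorch_model.bin`, with the sharded index files on the Hub as fallback.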
safe_filename = os.path.join(
pretrained_model_name_or_path, "model.safetensors"
)
filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
sharded_index_filename = os.path.join(
pretrained_model_name_or_path, "pytorch_model.bin.index.json"
)
safe_sharded_index_filename = os.path.join(
pretrained_model_name_or_path, "model.safetensors.index.json"
)
is_sharded = False
use_safe = os.path.exists(safe_filename)
if not (os.path.exists(filename) or os.path.exists(safe_filename)):
            # No usable local checkpoint; try to resolve `pytorch_model.bin`
            # (or its sharded index) from the Hub.
filename, files_to_download, is_sharded, is_resuming_training = (
cls._get_checkpoint_from_hub(
pretrained_model,
pretrained_model_name_or_path,
sharded_index_filename,
token=token,
)
)
# Try with safetensors
if filename is None and files_to_download is None:
(
safe_filename,
files_to_download,
is_sharded,
is_resuming_training,
) = cls._get_checkpoint_from_hub(
pretrained_model,
pretrained_model_name_or_path,
safe_sharded_index_filename,
token=token,
model_name="model.safetensors",
model_index_name="model.safetensors.index.json",
)
use_safe = True
else:
use_safe = False
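        # Pick the loader matching the checkpoint format; `torch.load` is pinned
        # to CPU via `map_location`, while `safetensors` loads to CPU by default.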
loading_func = safe_load_file if use_safe else torch.load
load_kwargs = {} if use_safe else {"map_location": "cpu"}
if is_resuming_training:
if is_sharded:
# download each file and add it to the state_dict
state_dict = {}
for shard_file in files_to_download:
filename = hf_hub_download(
pretrained_model_name_or_path,
shard_file,
token=token,
)
state_dict.update(loading_func(filename, **load_kwargs))
else:
state_dict = loading_func(
filename if not use_safe else safe_filename, **load_kwargs
)
else:
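        # A `PreTrainedModel` instance was passed in directly; reuse its current
        # weights as the state_dict handed to `post_init`.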
state_dict = pretrained_model_name_or_path.state_dict()
model.is_peft_model = is_peft_model
model.current_device = current_device
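    # `post_init` lets the concrete wrapper finish loading from the collected state_dict.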
if is_resuming_training:
model.post_init(state_dict=state_dict)
return model