in src/peft/tuners/tuners_utils.py [0:0]
def _cast_adapter_dtype(self, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
    """
    Upcast the half-precision weights of a single adapter to float32.

    Each `BaseTunerLayer` found in `self.model` is searched for `nn.ModuleDict`, `nn.ParameterDict` and
    `BufferDict` containers holding an entry under `adapter_name`. Values belonging to that entry whose dtype is
    `torch.float16` or `torch.bfloat16` are converted to `torch.float32`; every other dtype is left as it is.

    Args:
        adapter_name (`str`):
            The key of the adapter whose weights should be cast.
        autocast_adapter_dtype (`bool`, *optional*):
            If `False`, the method returns right away without modifying anything. Defaults to `True`.
    """
    if not autocast_adapter_dtype:
        return

    half_dtypes = {torch.float16, torch.bfloat16}

    # Lazily walk the model and keep only the tuner layers.
    tuner_layers = (layer for layer in self.model.modules() if isinstance(layer, BaseTunerLayer))
    for layer in tuner_layers:
        for container in layer.modules():
            is_adapter_container = isinstance(container, (nn.ModuleDict, nn.ParameterDict, BufferDict))
            if not (is_adapter_container and adapter_name in container):
                continue

            entry = container[adapter_name]
            if isinstance(entry, nn.Parameter):
                # Overwrite `.data` on the existing Parameter instead of re-assigning the dict entry.
                if entry.dtype in half_dtypes:
                    entry.data = entry.data.to(torch.float32)
            elif isinstance(entry, torch.Tensor):
                # Plain tensors (e.g. from a BufferDict) are stored back under the same key.
                if entry.dtype in half_dtypes:
                    container[adapter_name] = entry.to(torch.float32)
            else:
                # Remaining entries are treated as modules: cast each of their parameters.
                for weight in entry.parameters():
                    if weight.dtype in half_dtypes:
                        weight.data = weight.data.to(torch.float32)