in src/peft/tuners/tuners_utils.py [0:0]
def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
    """
    Move the adapter of the given name to the device of the base layer.

    Args:
        adapter_name (`str`):
            Name of the adapter whose weights should be moved. Entries of other adapters stored on this layer are
            not moved.
        device (`torch.device`, *optional*):
            Target device. If `None`, the target device and dtype are taken from the base layer's `weight` attribute,
            falling back to `qweight` (GPTQ); if neither exists, nothing is moved. If given explicitly, the adapter is
            moved to this device and its dtype is left unchanged, since there is no reference weight to copy it from.
    """
    # Only known when inferred from the base layer below; stays None when `device` is passed explicitly, in which
    # case no dtype cast is performed.
    dtype: Optional[torch.dtype] = None
    if device is None:
        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.MultiheadAttention):
            # use the output projection of nn.MultiheadAttention as the reference layer
            base_layer = base_layer.out_proj
        # check weight and qweight (for GPTQ)
        for weight_name in ("weight", "qweight"):
            weight = getattr(base_layer, weight_name, None)
            if weight is not None:
                device = weight.device
                dtype = weight.dtype
                break
        else:
            # no break encountered: could not determine the device
            return

    # Only adopt the reference dtype when it is floating point or complex; for any other dtype (e.g. an integer
    # GPTQ `qweight`) the adapter is only moved to the target device.
    cast_to_dtype = dtype is not None and (dtype.is_floating_point or dtype.is_complex)

    meta = torch.device("meta")
    # loop through all potential adapter layers and move them to the device of the base layer; be careful to only
    # move this specific adapter to the device, as the other adapters could be on different devices
    # see #1639
    for adapter_layer_name in self.adapter_layer_names + self.other_param_names:
        adapter_layer = getattr(self, adapter_layer_name, None)
        if not isinstance(adapter_layer, (nn.ModuleDict, nn.ParameterDict, BufferDict)):
            continue
        if adapter_name not in adapter_layer:
            continue
        # NOTE(review): this scans the parameters of every adapter held in this dict, not only `adapter_name`, so a
        # meta-device parameter belonging to another adapter also skips this one — confirm that is intended.
        if any(p.device == meta for p in adapter_layer.parameters()):
            continue
        if cast_to_dtype:
            adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=dtype)
        else:
            adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)