in src/transformers/generation/utils.py
def _get_layer_device_map_for_cache_init(self) -> Optional[dict[int, Union[str, int]]]:
"""
        Returns the device map for each decoder layer, so the cache can be allocated on the right device.
        Returns `None` when the model has no `hf_device_map` (the cache then falls back to `self.device`).
        Inspired by `dispatch_model` in accelerate.
"""
execution_device_map = None
if hasattr(self, "hf_device_map"):
            device_values = set(self.hf_device_map.values())
            if device_values == {"cpu"} or device_values == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}
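            # Example with a hypothetical `hf_device_map` such as
            # {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": "cpu", "lm_head": 0}:
            # `main_device` is 0, and `execution_device_map` remaps the offloaded ("cpu"/"disk")
            # entries to 0, since accelerate executes offloaded modules on the main device.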
# No `execution_device_map` -> rely on `self.device` to allocate the cache
if execution_device_map is None:
return None
        num_hidden_layers = self.config.get_text_config().num_hidden_layers
        # Single device for all layers
if len(execution_device_map) == 1 and "" in execution_device_map:
return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])
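            # Example: `execution_device_map == {"": 0}` (the whole model on device 0) yields
            # {0: 0, 1: 0, ..., num_hidden_layers - 1: 0}.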
# Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device.
layer_device_map = {}
        # Case 1: The model has a `get_decoder` method; use it to find the decoder's module name.
if hasattr(self, "get_decoder"):
decoder_name = None
for name, module in self.named_modules():
if module is self.get_decoder():
decoder_name = name
break
if decoder_name is None:
raise RuntimeError(
"`model.get_decoder()` is not returning a named module of the model. This is unexpected, please "
"open an issue on GitHub."
)
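            # e.g. for a typical `*ForCausalLM` whose `get_decoder()` returns `self.model`,
            # `decoder_name` resolves to "model".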
decoder_mapped_modules = [
module_name for module_name in execution_device_map.keys() if decoder_name in module_name
]
# The decoder name may be present in `execution_device_map` in two forms:
# a) each layer has a device mapping
if len(decoder_mapped_modules) >= num_hidden_layers:
for idx in range(num_hidden_layers):
for module_name in decoder_mapped_modules:
if f".{idx}." in f"{module_name}.":
layer_device_map[idx] = execution_device_map[module_name]
break
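                # Example (hypothetical names): with `decoder_name == "model"` and per-layer entries
                # such as {"model.layers.0": 0, "model.layers.1": 1, ...}, the f".{idx}." match above
                # yields layer_device_map == {0: 0, 1: 1, ...}.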
            # b) the whole decoder (or one of its parent modules) is mapped to a single device. If the
            # decoder name is NOT present in the device map, the mapping was done on a parent module ->
            # walk up the module tree until we find a mapped name.
else:
while True:
if decoder_name in execution_device_map:
layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name])
break
elif "." in decoder_name:
decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module
else:
raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map")
        # Case 2: Legacy code path: assume the decoder layers are named `(...).X` (X being the layer index)
else:
            for module_name in execution_device_map:
                for idx in range(num_hidden_layers):
                    if f".{idx}." in f"{module_name}.":
                        layer_device_map[idx] = execution_device_map[module_name]
                        break
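            # Example (hypothetical names): {"transformer.h.0": 0, "transformer.h.1": 1} maps
            # layer 0 -> device 0 and layer 1 -> device 1 via the same f".{idx}." suffix match.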
for idx in range(num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map
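    # ----------------------------------------------------------------------
    # Usage sketch (illustrative, not part of this file): how the mapping is
    # typically obtained. The checkpoint name and `device_map` value below are
    # assumptions for the example, and `device_map="auto"` requires the
    # `accelerate` package to be installed.
    #
    #     from transformers import AutoModelForCausalLM
    #
    #     model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
    #     layer_device_map = model._get_layer_device_map_for_cache_init()
    #     # On a single-GPU machine this is typically {0: 0, 1: 0, ..., 11: 0};
    #     # it is None when the model was loaded without a `device_map`.
    # ----------------------------------------------------------------------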