def _get_layer_device_map_for_cache_init()

in src/transformers/generation/utils.py


    def _get_layer_device_map_for_cache_init(self) -> Optional[dict[int, Union[str, int]]]:
        """
        Returns the device map for each decoder layer, to allocate the cache on the right device.
        Inspired from `dispatch_model` in accelerate.
        """
        execution_device_map = None

        if hasattr(self, "hf_device_map"):
            if set(self.hf_device_map.values()) in ({"cpu"}, {"cpu", "disk"}):
                main_device = "cpu"
            else:
                main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
            execution_device_map = {
                name: main_device if device in ["cpu", "disk"] else device
                for name, device in self.hf_device_map.items()
            }
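        # e.g. (illustrative names) {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": "cpu"}
        # becomes {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 0}: weights offloaded to
        # "cpu"/"disk" are still executed on the main device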

        # No `execution_device_map` -> rely on `self.device` to allocate the cache
        if execution_device_map is None:
            return None

        # Single device for all layers
        num_hidden_layers = self.config.get_text_config().num_hidden_layers
        if len(execution_device_map) == 1 and "" in execution_device_map:
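            # e.g. a device map of {"": 0} (hypothetical) maps every decoder layer to device 0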
            return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])

        # Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device.
        layer_device_map = {}
        # Case 1: the model has a `get_decoder` method; use it to find the decoder's module name.
        if hasattr(self, "get_decoder"):
            decoder = self.get_decoder()
            decoder_name = None
            for name, module in self.named_modules():
                if module is decoder:
                    decoder_name = name
                    break
            if decoder_name is None:
                raise RuntimeError(
                    "`model.get_decoder()` is not returning a named module of the model. This is unexpected; please "
                    "open an issue on GitHub."
                )

            decoder_mapped_modules = [
                module_name for module_name in execution_device_map.keys() if decoder_name in module_name
            ]
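            # e.g. (illustrative) with decoder_name = "model", keys like "model.layers.0" and
            # "model.layers.1" are collected here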
            # The decoder name may be present in `execution_device_map` in two forms:
            # a) each layer has a device mapping
            if len(decoder_mapped_modules) >= num_hidden_layers:
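                # e.g. an entry like "model.layers.3" (illustrative name) maps layer 3 to its device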
                for idx in range(num_hidden_layers):
                    for module_name in decoder_mapped_modules:
                        if f".{idx}." in f"{module_name}.":
                            layer_device_map[idx] = execution_device_map[module_name]
                            break

            # b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map,
            # then the mapping is done in a parent module
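            # (e.g., illustratively, execution_device_map = {"model": 0} with decoder_name = "model.decoder":
            # we climb from "model.decoder" to "model" and find the device there)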
            else:
                while True:
                    if decoder_name in execution_device_map:
                        layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name])
                        break
                    elif "." in decoder_name:
                        decoder_name = decoder_name.rsplit(".", 1)[0]  # gets the name of the parent module
                    else:
                        raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map")

        # Case 2: legacy code path. Assume the decoder layers are named `(...).X` (X being the layer index)
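        # (e.g. a key like "transformer.h.4" would map layer 4; the name is illustrative)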
        else:
            for module_name in execution_device_map:
                for idx in range(num_hidden_layers):
                    if f".{idx}." in f"{module_name}.":
                        layer_device_map[idx] = execution_device_map[module_name]
                        break

        for idx in range(num_hidden_layers):
            if idx not in layer_device_map:
                raise RuntimeError(f"layer {idx} has not been mapped to a device.")
        return layer_device_map
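
For intuition, here is a minimal, self-contained sketch of the per-layer matching used above (case a), run
against a hypothetical `execution_device_map`; the module names, devices, and layer count are made up for
illustration and do not come from a real model:

    num_hidden_layers = 4
    # Hypothetical execution device map: first half of the decoder on device 0, second half on device 1
    execution_device_map = {
        "model.embed_tokens": 0,
        "model.layers.0": 0,
        "model.layers.1": 0,
        "model.layers.2": 1,
        "model.layers.3": 1,
        "lm_head": 1,
    }

    layer_device_map = {}
    for idx in range(num_hidden_layers):
        for module_name in execution_device_map:
            # Appending "." lets ".3." also match a module name that ends with ".3"
            if f".{idx}." in f"{module_name}.":
                layer_device_map[idx] = execution_device_map[module_name]
                break

    print(layer_device_map)  # {0: 0, 1: 0, 2: 1, 3: 1}

Note that a name like "model.layers.11" does not match layer 1, since ".1." is not a substring of
"model.layers.11.": the surrounding dots make the index match exact.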