# Source: src/accelerate/utils/modeling.py [0:0]
def clean_device_map(device_map: dict[str, Union[int, str, torch.device]], module_name: str = ""):
    """
    Cleans a device_map by grouping all submodules that go on the same device together.

    When two or more entries located under `module_name` all map to the same device, those entries are removed
    and replaced by a single `module_name` entry. The function then recurses into each direct child module so
    that the same grouping is applied at every depth.

    Args:
        device_map (`dict[str, Union[int, str, torch.device]]`):
            Mapping from dotted module names to devices. It is modified in place.
        module_name (`str`, *optional*, defaults to `""`):
            Dotted name of the module whose entries are examined. The empty string designates the root, in which
            case every key of `device_map` is considered.

    Returns:
        The same `device_map` object that was passed in, after cleaning. A regrouped entry is re-inserted at the
        end of the dictionary, so key order may differ from the input order.
    """
    # The trailing dot makes sure "layer.1" does not capture the keys of "layer.10".
    prefix = "" if module_name == "" else f"{module_name}."

    # Get the value of the current module and if there is only one split across several keys, regroup it.
    matching_keys = [k for k in device_map if k.startswith(prefix)]
    values = [device_map[k] for k in matching_keys]
    if len(values) > 1 and len(set(values)) == 1:
        for k in matching_keys:
            del device_map[k]
        device_map[module_name] = values[0]

    # Recurse over the children. `idx` is the number of dotted components in a direct child's name.
    idx = len(module_name.split(".")) + 1 if module_name else 1
    # `dict.fromkeys` de-duplicates while keeping first-seen order, so children are visited (and regrouped
    # entries re-inserted) in a reproducible order, which iterating over a `set` of strings does not guarantee.
    # The generator is fully consumed here, before the recursive calls start mutating `device_map`.
    children_modules = dict.fromkeys(
        ".".join(k.split(".")[:idx])
        for k in device_map
        if k.startswith(prefix) and len(k) > len(module_name)
    )
    for child in children_modules:
        clean_device_map(device_map, module_name=child)

    return device_map