in src/lighteval/models/transformers/vlm_transformers_model.py
def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool, Optional[dict], Optional[str]]:
    """Compute all the parameters related to model_parallel."""
    if not is_accelerate_available():
        return False, None, None

    self.num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
    self.num_machines = torch.cuda.device_count() // self.num_local_processes

    if self.num_machines == 1:
        logger.info("We are not in a distributed setting. Setting model_parallel to False.")
        model_parallel = False

    if model_parallel is None:
        # Auto-detect: enable model parallelism when there are more GPUs than local processes.
        max_memory_all_gpus = get_max_memory()  # A dict of the max memory for all the gpus
        if "cpu" in max_memory_all_gpus:
            del max_memory_all_gpus["cpu"]
        model_parallel = bool(self.num_local_processes < len(max_memory_all_gpus))
        logger.info(
            f"Setting model parallel to {model_parallel} since "
            f"the number of local processes is {self.num_local_processes} "
            f"and the number of GPUs is {len(max_memory_all_gpus)}"
        )

    if model_parallel is True:
        max_memory_all_gpus = get_max_memory()  # A dict of the max memory for all the gpus
        if "cpu" in max_memory_all_gpus:
            del max_memory_all_gpus["cpu"]
        # Keep only the GPUs assigned to this local process (round-robin over device indices).
        max_mem_this_process = {
            k: v
            for k, v in max_memory_all_gpus.items()
            if k % self.num_local_processes == (self.accelerator.process_index % self.num_local_processes)
        }
        device_map = "auto"
        logger.info(
            f"Model parallel was set to True, setting max memory per GPU to {max_mem_this_process} and device map to {device_map}"
        )
    else:
        max_mem_this_process = None
        device_map = None
        logger.info(
            f"Model parallel was set to False, max memory set to {max_mem_this_process} and device map to {device_map}"
        )

    return model_parallel, max_mem_this_process, device_map
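
For context, a minimal standalone sketch of the GPU-partitioning step above. This is an illustration only, not code from the file: it assumes accelerate is installed, that LOCAL_WORLD_SIZE and LOCAL_RANK are set by the launcher (as torchrun does), and that the local rank stands in for self.accelerator.process_index modulo the local process count.

# Hedged sketch, not from the source file: mirrors the round-robin GPU split above.
import os

from accelerate.utils import get_max_memory

num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # assumption: plays the role of process_index % num_local_processes

max_memory_all_gpus = get_max_memory()  # e.g. {0: bytes, 1: bytes, ..., "cpu": bytes}
max_memory_all_gpus.pop("cpu", None)

# Each local process keeps the GPUs whose index falls in its round-robin share.
max_mem_this_process = {
    gpu_index: mem
    for gpu_index, mem in max_memory_all_gpus.items()
    if gpu_index % num_local_processes == local_rank
}
print(max_mem_this_process)  # e.g. {0: ..., 2: ...} for local rank 0 with 2 processes and 4 GPUs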