in src/lighteval/utils/parallelism.py
def test_all_gather(accelerator=None, parallel_context=None):
    """
    Test the gather operation in a parallel setup.

    Args:
        accelerator (Optional): The accelerate `Accelerator` object, if running with Hugging Face Accelerate.
        parallel_context (Optional): The nanotron `ParallelContext` object, if running with nanotron.

    Raises:
        ImportError: If the required parallelism library (accelerate or nanotron) is not installed.
    """
    if accelerator:
        if not is_accelerate_available():
            raise ImportError(NO_ACCELERATE_ERROR_MSG)
        logger.info("Test gather tensor")
        test_tensor: torch.Tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
        gathered_tensor: torch.Tensor = accelerator.gather(test_tensor)
        logger.info(f"gathered_tensor {gathered_tensor}, should be {list(range(accelerator.num_processes))}")
        accelerator.wait_for_everyone()
    elif parallel_context:
        if not is_nanotron_available():
            raise ImportError(NO_NANOTRON_ERROR_MSG)
        from nanotron import distributed as dist
        from nanotron import logging

        logger.info("Test gather tensor")
        # Do a first NCCL sync to warm up and try to avoid a timeout after model/data loading
        logging.log_rank(
            f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}",
            logger=logger,
            level=logging.WARNING,
            group=parallel_context.dp_pg,
            rank=0,
        )
        test_tensor = torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda"))
        test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())]
        dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
        dist.barrier()
        logging.log_rank(
            f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}",
            logger=logger,
            level=logging.WARNING,
            group=parallel_context.dp_pg,
            rank=0,
        )
        del test_tensor_list
        del test_tensor
    else:
        logger.info("Not running in a parallel setup, nothing to test")