def test_all_gather()

in src/lighteval/utils/parallelism.py
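
For context, the excerpt relies on module-level names that are not shown here (torch, logger, the availability checks and their error messages). A sketch of that surrounding setup, with the lighteval-internal import left commented out because its exact path is an assumption:

import logging

import torch

# Assumed to come from lighteval's import utilities (exact path not shown in the excerpt):
# from lighteval.utils.imports import (
#     NO_ACCELERATE_ERROR_MSG,
#     NO_NANOTRON_ERROR_MSG,
#     is_accelerate_available,
#     is_nanotron_available,
# )

logger = logging.getLogger(__name__)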


def test_all_gather(accelerator=None, parallel_context=None):
    """
    Run a small all-gather of each process's rank to check that the
    distributed setup (accelerate or nanotron) is working.

    Args:
        accelerator (Optional): Accelerate Accelerator object used when running with the accelerate backend.
        parallel_context (Optional): Nanotron ParallelContext object used when running with the nanotron backend.

    Raises:
        ImportError: If accelerator is passed but accelerate is not installed,
            or parallel_context is passed but nanotron is not installed.
    """
    if accelerator:
        if not is_accelerate_available():
            raise ImportError(NO_ACCELERATE_ERROR_MSG)
        logger.info("Test gather tensor")
        test_tensor: torch.Tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
        gathered_tensor: torch.Tensor = accelerator.gather(test_tensor)
        logger.info(f"gathered_tensor {gathered_tensor}, should be {list(range(accelerator.num_processes))}")
        accelerator.wait_for_everyone()
    elif parallel_context:
        if not is_nanotron_available():
            raise ImportError(NO_NANOTRON_ERROR_MSG)
        from nanotron import distributed as dist
        from nanotron import logging

        logger.info("Test gather tensor")
        # Do a first NCCL sync to warm up and to avoid a timeout after model/data loading
        logging.log_rank(
            f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}",
            logger=logger,
            level=logging.WARNING,
            group=parallel_context.dp_pg,
            rank=0,
        )
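        # Gather every rank's id over the full world group; each rank should
        # end up with [0, 1, ..., world_size - 1] in test_tensor_list.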
        test_tensor = torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda"))
        test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())]
        dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
        dist.barrier()
        logging.log_rank(
            f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}",
            logger=logger,
            level=logging.WARNING,
            group=parallel_context.dp_pg,
            rank=0,
        )

        del test_tensor_list
        del test_tensor
    else:
        logger.info("Not running in a parallel setup, nothing to test")