def benchmark()

in xformers/benchmarks/LRA/run_tasks.py


def benchmark(rank, args):
    # Setup multiprocessing
    dist.init_process_group(
        init_method="file://" + args.temp_file,
        backend="NCCL",
        rank=rank,
        world_size=args.world_size,
    )
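    # args.gpu is only defined by some launchers; otherwise fall back to using the rank
    # as the local device index (the "single node" case below)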
    try:
        torch.cuda.set_device(args.gpu)
    except AttributeError:
        # Single node launcher
        torch.cuda.set_device(rank)

    task = args.task
    attention_name = args.attention

    # Build the problem
    log_f_path, logger = setup_log(args, rank, attention_name, task)
    args.logger = logger
    config = load_config(args.config)

    config_task = config[f"{task}"]
    if args.sweep_parameters is not None:
        logger.info("Replacing hyperparameters")
        rewrite_hyper(config_task, args.sweep_parameters)

    config_training = config_task["training"]
    config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
    model = build_model(args, config)

    torch.manual_seed(config_training.get("seed", 0))  # also sets the cuda seed
    np.random.seed(config_training.get("seed", 0))
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.reset_peak_memory_stats()

    # tensorboard
    tb_logger = SummaryWriter(args.tb_dir)

    # Setup the training
    device_ids = list(range(torch.cuda.device_count()))
    logger.info(f"GPU list: {device_ids}")
    model = model.cuda()
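    # Wrap in DDP; find_unused_parameters=True keeps training working for attention
    # variants that do not touch every parameter in a given forward pass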
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True
    )

    (
        datasets,
        samplers,
        optimizer,
        lr_scheduler,
        amp_scaler,
    ) = build_training_setup(config_training, task, model, rank, args.world_size)

    init_t = time.time()

    # Messenger structure which will be moved around to collect metrics
    summary = {
        comp: {
            "t": 0,
            "loss": [],
            "accu": [],
            "count": [],
            "best_accu": 0,
            "component": comp,
        }
        for comp in ["train", "dev", "test"]
    }

    # Setup the dataloaders
    accumu_steps = config_task["training"]["gradient_accumulation"]
    per_gpu_batch_size = (
        config_training["batch_size"] // args.world_size // accumu_steps
    )
    logging.warning(
        "Requested batch size: {}. Given world size and grad accumulation, per-gpu batch is {}".format(
            config_training["batch_size"], per_gpu_batch_size
        )
    )

    # reset train/eval steps if using gradient accumulation
    if accumu_steps > 1:
        config_training["num_train_steps"] *= accumu_steps
        config_training["num_eval_steps"] *= accumu_steps

    epochs = math.ceil(
        config_training["num_train_steps"]
        * config_training["batch_size"]
        / len(datasets["train"])
    )

    logging.warning(
        "Requested train steps: {}. Given dataset, this translates into {} epochs".format(
            config_training["num_train_steps"], epochs
        )
    )

    logger.info(f"accumu_steps={accumu_steps}")
    model_path = str(log_f_path).replace(".log", ".model")
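    # Dedicated RNG so dataloader shuffling and worker seeding stay reproducible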
    g = torch.Generator()
    g.manual_seed(config_training.get("seed", 0))

    dataloaders = {
        k: DataLoader(
            datasets[k],
            sampler=samplers[k],
            batch_size=per_gpu_batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=1,
            worker_init_fn=seed_worker,
            generator=g,
        )
        for k in datasets.keys()
    }

    # Our step function
    def step(
        batch: Dict[str, Any],
        component: str,
        step_idx: int,
        step_max: int,
        accumulate: bool = False,
    ):
        if step_idx > step_max:
            logger.warning(
                "Calling `step` beyond the training schedule, this is probably a mistake"
            )
            return

        t0 = time.time()
        batch_size = batch[list(batch.keys())[0]].size(0)

        for key in batch:
            batch[key] = batch[key].cuda()

        if component == "train":
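            # While accumulating, model.no_sync() skips DDP's gradient all-reduce;
            # gradients are only synchronized on the pass that actually steps the optimizer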
            acc_context = model.no_sync() if accumulate else suppress()

            with acc_context, torch.autograd.set_detect_anomaly(args.debug):
                outputs = model(**batch)
                amp_scaler.scale(outputs["loss"]).backward()

                if not accumulate:
                    amp_scaler.step(optimizer)
                    optimizer.zero_grad()
                    amp_scaler.update()
                    lr_scheduler.step()

        else:
            with torch.no_grad():
                outputs = model(**batch)

        t1 = time.time()

        t_escape = t1 - t0
        learning_rate = optimizer.param_groups[0]["lr"]
        loss = outputs["loss"].item()
        accu = outputs["accu"].item()
        cnt = outputs["count"]
        time_since_start = time.time() - init_t
        eta = (
            datetime.timedelta(
                seconds=round(time_since_start / (step_idx + 1) * step_max)
            )
            if component == "train"
            else -1
        )

        if not step_idx % 10:
            logger.info(
                f"{component}: step={step_idx}/{step_max}, total_time={time_since_start:.1f},"
                + f" eta={eta},"
                + f" batch_time={t_escape:.3f}, bs={batch_size}, lr={learning_rate:.6f},"
                + f" loss={loss:.4f}, accu={accu:.4f}",
            )

        summary[component]["t"] += t_escape
        summary[component]["loss"].append(loss)
        summary[component]["accu"].append(accu)
        summary[component]["count"].append(cnt)

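        # Only completed optimizer steps advance the step index; accumulation
        # micro-batches do not count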
        if not accumulate:
            step_idx += 1

        return loss, step_idx

    # Start training or evaluating
    train_step_idx = 0
    if not args.skip_train:
        try:
            model.train()
            for epoch in range(epochs):
                logger.info(f"\nEpoch {epoch}")

                # Make sure that per-rank sampling is really random
                for sampler in samplers.values():
                    sampler.set_epoch(epoch)

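                # Sync and step the optimizer on every gradient_accumulation-th batch;
                # accumulate gradients locally in between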
                for i_batch, batch in enumerate(dataloaders["train"]):
                    grad_accumulate = (
                        i_batch % config_training["gradient_accumulation"] != 0
                    )

                    _, train_step_idx = step(
                        batch,
                        component="train",
                        step_idx=train_step_idx,
                        step_max=config_training["num_train_steps"],
                        accumulate=grad_accumulate,
                    )

                    if not (train_step_idx + 1) % config_training["eval_frequency"]:
                        print_summary(
                            summary["train"],
                            False,
                            train_step_idx,
                            model,
                            model_path,
                            logger,
                        )

                        eval_model(model, dataloaders, "dev", config_training, step)

                        print_summary(
                            summary["dev"],
                            True,
                            train_step_idx,
                            model,
                            model_path,
                            logger,
                            tb_logger,
                        )

                    if train_step_idx == config_training["num_train_steps"]:
                        break

        except KeyboardInterrupt as e:
            print(e)

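    # Reload the checkpoint saved during training and run a final pass on the test split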
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    try:
        eval_model(model, dataloaders, "test", config_training, step)
    except StopIteration:
        pass

    print_summary(summary["test"], False, train_step_idx, model, model_path, logger)
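
For context, here is a minimal sketch of how a per-rank entry point with this (rank, args) signature is typically spawned, one process per GPU. This is an illustration under assumptions, not the actual CLI of run_tasks.py: the real script builds `args` with argparse, and the placeholder fields below only mirror what benchmark() reads above (temp_file, world_size, task, attention, ...).

# Illustrative launcher sketch (assumption, not the actual run_tasks.py CLI).
import tempfile
from argparse import Namespace

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    args = Namespace(
        temp_file=tempfile.NamedTemporaryFile(delete=False).name,  # file:// rendezvous for init_process_group
        world_size=world_size,
        # task=..., attention=..., config=..., tb_dir=..., etc. (hypothetical placeholders)
    )
    # One process per GPU; each process calls benchmark(rank, args)
    mp.spawn(benchmark, args=(args,), nprocs=world_size)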