def main_subproc()

in fairdiplomacy/models/diplomacy_model/train_sl.py


def main_subproc(rank, world_size, args, train_set, val_set, extra_val_datasets):
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        # distributed training setup
        mp_setup(rank, world_size)
        atexit.register(mp_cleanup)
        torch.cuda.set_device(rank)
    else:
        assert rank == 0 and world_size == 1

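    # per-rank metric logger; rank 0 is treated as the master for logging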
    metric_logger = Logger(is_master=rank == 0)
    global_step = 0
    log_scalars = lambda **scalars: metric_logger.log_metrics(
        scalars, step=global_step, sanitize=True
    )

    # load checkpoint if specified
    if args.checkpoint and os.path.isfile(args.checkpoint):
        logger.info("Loading checkpoint at {}".format(args.checkpoint))
        checkpoint = torch.load(args.checkpoint, map_location="cuda:{}".format(rank))
    else:
        checkpoint = None

    logger.info("Init model...")
    net = new_model(args)

    # move the model to this rank's GPU and wrap it in DistributedDataParallel
    if has_gpu:
        logger.debug("net.cuda({})".format(rank))
        net.cuda(rank)
        logger.debug("net {} DistributedDataParallel".format(rank))
        net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank])
        logger.debug("net {} DistributedDataParallel done".format(rank))

    # load from checkpoint if specified
    if checkpoint:
        logger.debug("net.load_state_dict")
        net.load_state_dict(checkpoint["model"], strict=True)

    # create loss functions, optimizer, and LR scheduler; restore optimizer state from checkpoint if present
    policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    value_loss_fn = torch.nn.MSELoss(reduction="none")
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=args.lr_decay)
    if checkpoint:
        optim.load_state_dict(checkpoint["optim"])

    # restore best losses so we don't immediately overwrite the best checkpoints
    best_loss = checkpoint.get("best_loss") if checkpoint else None
    best_p_loss = checkpoint.get("best_p_loss") if checkpoint else None
    best_v_loss = checkpoint.get("best_v_loss") if checkpoint else None

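    # shard the training set across ranks when running distributed on GPUs;
    # otherwise draw a simple random permutation on the single CPU process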
    if has_gpu:
        train_set_sampler = DistributedSampler(train_set)
    else:
        train_set_sampler = RandomSampler(train_set)

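    # resume from the epoch after the checkpointed one, if a checkpoint was loaded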
    for epoch in range(checkpoint["epoch"] + 1 if checkpoint else 0, args.num_epochs):
        if has_gpu:
            train_set_sampler.set_epoch(epoch)
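        # materialize this epoch's sampling order and split it into batch_size-sized index chunks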
        batches = torch.tensor(list(iter(train_set_sampler)), dtype=torch.long).split(
            args.batch_size
        )
        for batch_i, batch_idxs in enumerate(batches):
            batch = train_set[batch_idxs]
            logger.debug(f"Zero grad {batch_i} ...")

            # skip batches with no real actions (all EOS)
            if (batch["y_actions"] == EOS_IDX).all():
                logger.warning("Skipping empty epoch {} batch {}".format(epoch, batch_i))
                continue

            # learn
            logger.debug("Starting epoch {} batch {}".format(epoch, batch_i))
            optim.zero_grad()
            policy_losses, value_losses, _, _ = process_batch(
                net,
                batch,
                policy_loss_fn,
                value_loss_fn,
                p_teacher_force=args.teacher_force,
                shuffle_locs=args.shuffle_locs,
            )

            # backward
            p_loss = torch.mean(policy_losses)
            v_loss = torch.mean(value_losses)
            loss = (1 - args.value_loss_weight) * p_loss + args.value_loss_weight * v_loss
            loss.backward()

            # clip gradients, step
            value_decoder_grad_norm = torch.nn.utils.clip_grad_norm_(
                getattr(net, "module", net).value_decoder.parameters(),
                args.value_decoder_clip_grad_norm,
            )
            grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip_grad_norm)
            optim.step()

            # log diagnostics
            if rank == 0 and batch_i % 10 == 0:
                scalars = dict(
                    epoch=epoch,
                    batch=batch_i,
                    loss=loss,
                    lr=optim.state_dict()["param_groups"][0]["lr"],
                    grad_norm=grad_norm,
                    value_decoder_grad_norm=value_decoder_grad_norm,
                    p_loss=p_loss,
                    v_loss=v_loss,
                )
                log_scalars(**scalars)
                logger.info(
                    "epoch {} batch {} / {}, ".format(epoch, batch_i, len(batches))
                    + " ".join(f"{k}= {v}" for k, v in scalars.items())
                )
            global_step += 1
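            # optionally cap the number of batches per epoch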
            if args.epoch_max_batches and batch_i + 1 >= args.epoch_max_batches:
                logger.info("Exiting early due to epoch_max_batches")
                break

        # calculate validation loss/accuracy
        if not args.skip_validation and rank == 0:
            logger.info("Calculating val loss...")
            (
                valid_loss,
                valid_p_loss,
                valid_v_loss,
                valid_p_accuracy,
                valid_v_accuracy,
                split_pcts,
            ) = validate(
                net,
                val_set,
                policy_loss_fn,
                value_loss_fn,
                args.batch_size,
                value_loss_weight=args.value_loss_weight,
            )
            scalars = dict(
                epoch=epoch,
                valid_loss=valid_loss,
                valid_p_loss=valid_p_loss,
                valid_v_loss=valid_v_loss,
                valid_p_accuracy=valid_p_accuracy,
                valid_v_accuracy=valid_v_accuracy,
            )
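            # also evaluate on any extra validation datasets and record their metrics under a per-dataset prefix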
            for name, extra_val_set in extra_val_datasets.items():
                (
                    scalars[f"valid_{name}/loss"],
                    scalars[f"valid_{name}/p_loss"],
                    scalars[f"valid_{name}/v_loss"],
                    scalars[f"valid_{name}/p_accuracy"],
                    scalars[f"valid_{name}/v_accuracy"],
                    _,
                ) = validate(
                    net,
                    extra_val_set,
                    policy_loss_fn,
                    value_loss_fn,
                    args.batch_size,
                    value_loss_weight=args.value_loss_weight,
                )

            log_scalars(**scalars)
            logger.info("Validation " + " ".join([f"{k}= {v}" for k, v in scalars.items()]))
            for k, v in sorted(split_pcts.items()):
                logger.info(f"val split epoch= {epoch} batch= {batch_i}: {k} = {v}")

            # save model
            if args.checkpoint and rank == 0:
                obj = {
                    "model": net.state_dict(),
                    "optim": optim.state_dict(),
                    "epoch": epoch,
                    "batch_i": batch_i,
                    "valid_p_accuracy": valid_p_accuracy,
                    "args": args,
                    "best_loss": best_loss,
                    "best_p_loss": best_p_loss,
                    "best_v_loss": best_v_loss,
                }
                logger.info("Saving checkpoint to {}".format(args.checkpoint))
                torch.save(obj, args.checkpoint)

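                # keep periodic snapshots plus best-so-far checkpoints by total, policy, and value validation loss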
                if epoch % 10 == 0:
                    torch.save(obj, args.checkpoint + ".epoch_" + str(epoch))
                if best_loss is None or valid_loss < best_loss:
                    best_loss = valid_loss
                    torch.save(obj, args.checkpoint + ".best")
                if best_p_loss is None or valid_p_loss < best_p_loss:
                    best_p_loss = valid_p_loss
                    torch.save(obj, args.checkpoint + ".bestp")
                if best_v_loss is None or valid_v_loss < best_v_loss:
                    best_v_loss = valid_v_loss
                    torch.save(obj, args.checkpoint + ".bestv")

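        # decay the learning rate once per epoch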
        lr_scheduler.step()
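
For context, a per-rank worker like this is typically launched once per visible GPU. Below is a minimal sketch of such a launcher, assuming torch.multiprocessing and using a hypothetical launch() wrapper; it is illustrative only and not the launcher actually used in train_sl.py.

import torch
import torch.multiprocessing as mp


def launch(args, train_set, val_set, extra_val_datasets):
    if torch.cuda.is_available():
        # one worker process per visible GPU; mp.spawn calls
        # main_subproc(rank, *args) for each rank in [0, world_size)
        world_size = torch.cuda.device_count()
        mp.spawn(
            main_subproc,
            nprocs=world_size,
            args=(world_size, args, train_set, val_set, extra_val_datasets),
        )
    else:
        # CPU-only path: main_subproc asserts rank == 0 and world_size == 1
        main_subproc(0, 1, args, train_set, val_set, extra_val_datasets)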