def main()

in main.py


def main(args):
    args = copy.deepcopy(args)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    update_args(args)

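    # set up distributed training, pick the compute device and create the logger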
    distributed.init(args)
    args.device = torch.device("cuda" if use_cuda else "cpu")
    logger = Logger(args)
    logger.print(f"PyTorch version: {torch.__version__}")
    logger.print(f"PyTorch CUDA version: {torch.version.cuda}")
    logger.print(str(args))

    # load data
    train_data, val_data, test_data, corpus = data.get_data(args, logger, args.data_eos)
    if len(args.data_omit_labels) > 0:
        args.data_omit_label_idx = [
            corpus.dictionary.word2idx[w] for w in args.data_omit_labels
        ]
    else:
        args.data_omit_label_idx = None

    # create a model
    if args.feedback:
        model = feedback.FeedbackTransformer(args)
    elif args.expire_span:
        model = expire_span.ExpireSpan(args)
    elif args.compress:
        model = compressive.CompressiveTransformer(args)
    else:
        model = transformer_seq.TransformerSeq(args)
    model.to(args.device)

    # count trainable parameters
    nparameters = 0
    params = []
    for param in model.parameters():
        if param.requires_grad:
            nparameters += param.numel()
            params.append(param)
    logger.print("nparameters={:.2f}M".format(nparameters / 1e6))

    # optimizer
    if args.optim == "sgd":
        optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum)
    elif args.optim == "adam":
        optimizer = optim.Adam(params, lr=args.lr)
    else:
        raise ValueError(f"optim={args.optim} is not supported")

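    # learning-rate schedule: cosine decay over all training steps,
    # a linear warm-up, or no scheduler at all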
    if args.lr_decay:
        # will do warm-up manually later
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.nepochs * args.nbatches
        )
    elif args.lr_warmup > 0:
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: min(1, ep / args.lr_warmup)
        )
    else:
        scheduler = None

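    # wrap the model for distributed training and restore any existing
    # checkpoint (ep_init is the epoch to resume from)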
    model = distributed.wrap_model(args, model)

    ep_init = checkpoint.load(args, model, optimizer, logger, scheduler)

    # pos: data sampling (0=sequential, -1=random)
    pos = [0 for _ in range(3)]
    if isinstance(train_data, tuple):
        pos[0] = random.randrange(train_data[0].size(1) - args.mem_sz)
    else:
        pos[0] = random.randrange(train_data.size(1) - args.mem_sz)
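    # one hidden-state cache per split: train, validation and test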
    hid_cache = [
        model.module.init_hid_cache(args.batch_sz),
        model.module.init_hid_cache(args.test_batch_sz),
        model.module.init_hid_cache(args.test_batch_sz),
    ]

    if args.full_test:
        # perform evaluation only
        with torch.no_grad():
            stat_val, pos[1], hid_cache[1] = train(
                args,
                model,
                optimizer,
                scheduler,
                val_data,
                test_only=True,
                train_pos=pos[1],
                h_cache=hid_cache[1],
                corpus=corpus,
            )
            stat_test, pos[2], hid_cache[2] = train(
                args,
                model,
                optimizer,
                scheduler,
                test_data,
                test_only=True,
                train_pos=pos[2],
                h_cache=hid_cache[2],
                corpus=corpus,
            )
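            # record peak GPU memory (GiB) and collect the evaluation
            # stats across distributed workers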
            gpu_mem = torch.cuda.max_memory_allocated() / 1024 ** 3
            stat_test, stat_val, gpu_mem = distributed.collect_stat(
                args, stat_test, stat_val, gpu_mem
            )
            if args.data_type == "char":
                if "err" in stat_val:
                    logger.print("val err: {:.3f}%".format(stat_val["err"] * 100))
                    logger.print("test err: {:.3f}%".format(stat_test["err"] * 100))
                else:
                    logger.print(
                        "val: {:.3f}bpc".format(stat_val["loss"] / math.log(2))
                    )
                    logger.print(
                        "test: {:.3f}bpc".format(stat_test["loss"] / math.log(2))
                    )
            else:
                logger.print("val: {:.3f}ppl".format(math.exp(stat_val["loss"])))
                logger.print("test: {:.3f}ppl".format(math.exp(stat_test["loss"])))
            logger.print(f"gpu_mem: {gpu_mem:.1f}gb")
        return

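    # training loop: one pass over the training data per epoch, followed by validation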
    for ep in range(ep_init, args.nepochs):
        t_sta = time.time()
        args.ep = ep
        stat_train, pos[0], hid_cache[0] = train(
            args,
            model,
            optimizer,
            scheduler,
            train_data,
            train_pos=pos[0],
            h_cache=hid_cache[0],
            corpus=corpus,
        )
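        # average wall-clock time per training batch, in milliseconds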
        elapsed = 1000 * (time.time() - t_sta) / args.nbatches
        with torch.no_grad():
            if args.full_valid:
                stat_val, _, _ = train(
                    args,
                    model,
                    optimizer,
                    scheduler,
                    val_data,
                    test_only=True,
                    train_pos=pos[1],
                    h_cache=hid_cache[1],
                    corpus=corpus,
                )
            else:
                stat_val, pos[1], hid_cache[1] = train(
                    args,
                    model,
                    optimizer,
                    scheduler,
                    val_data,
                    test_only=True,
                    train_pos=pos[1],
                    h_cache=hid_cache[1],
                    corpus=corpus,
                )

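        # peak GPU memory (GiB) for this epoch; reset the counter for the next one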
        gpu_mem = torch.cuda.max_memory_allocated() / 1024 ** 3
        torch.cuda.reset_max_memory_allocated()
        stat_train, stat_val, gpu_mem = distributed.collect_stat(
            args, stat_train, stat_val, gpu_mem
        )

        if args.rank == 0:
            # only the master process does logging, plotting and checkpointing
            if args.lr_decay:
                logger.log("compute/lr", optimizer.param_groups[0]["lr"])
            if args.adapt_span:
                adaptive_span.log(args, model, logger, stat_train)
            if args.expire_span:
                expire_span.log(args, model, logger, stat_train)
            if args.feedback:
                feedback.log(args, model, logger, stat_train)

            logger.step(args, stat_train, stat_val, elapsed, gpu_mem)
            checkpoint.save(args, model, optimizer, logger, scheduler)
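

For reference, below is a minimal, self-contained sketch of the warm-up schedule built above. It is illustrative only: base_lr and lr_warmup stand in for args.lr and args.lr_warmup, a throwaway parameter replaces the real model, and lr_warmup is assumed to count scheduler steps.

import torch
from torch import optim

base_lr = 0.1
lr_warmup = 4  # assumed number of warm-up steps (args.lr_warmup in main)
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=base_lr)
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer, lambda ep: min(1, ep / lr_warmup)
)
for step in range(6):
    print(step, optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()
# lr ramps linearly: 0.0, 0.025, 0.05, 0.075, then stays at 0.1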