def train()

in train.py


# Imports inferred from the calls in this excerpt. `models`, `utils`, and the
# modules under datasets/ are project-local; `test`, `checkpoint`, and
# `compute_edit_distance` are assumed to be defined elsewhere in train.py.
import itertools
import json
import logging
import os
import time

import torch

import models
import utils

def train(world_rank, args):
    # setup logging: only rank 0 logs at INFO; other ranks are raised to
    # CRITICAL so distributed runs do not emit duplicate log lines
    level = logging.INFO
    if world_rank != 0:
        level = logging.CRITICAL
    logging.getLogger().setLevel(level)

    with open(args.config, "r") as fid:
        config = json.load(fid)
        logging.info("Using the config \n{}".format(json.dumps(config)))

    is_distributed_train = False
    if args.world_size > 1:
        is_distributed_train = True
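        # Blocks until all `world_size` processes have joined the rendezvous
        # at `dist_url` (e.g. a tcp:// or env:// address).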
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=world_rank,
        )

    if not args.disable_cuda:
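        # Assumes one process per GPU on a single node, so the global rank
        # doubles as the local CUDA device index.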
        device = torch.device("cuda")
        torch.cuda.set_device(world_rank)
    else:
        device = torch.device("cpu")

    # seed everything:
    seed = config.get("seed", None)
    if seed is not None:
        torch.manual_seed(seed)

    # setup data loaders:
    logging.info("Loading dataset ...")
    dataset = config["data"]["dataset"]
    if not os.path.exists(f"datasets/{dataset}.py"):
        raise ValueError(f"Unknown dataset {dataset}")
    dataset = utils.module_from_file("dataset", f"datasets/{dataset}.py")
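    # `utils.module_from_file` presumably wraps importlib to load the dataset
    # module from its file path; each datasets/<name>.py script is then
    # expected to expose the `Preprocessor` and `Dataset` classes used below.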

    input_size = config["data"]["num_features"]
    data_path = config["data"]["data_path"]
    preprocessor = dataset.Preprocessor(
        data_path,
        num_features=input_size,
        tokens_path=config["data"].get("tokens", None),
        lexicon_path=config["data"].get("lexicon", None),
        use_words=config["data"].get("use_words", False),
        prepend_wordsep=config["data"].get("prepend_wordsep", False),
    )
    trainset = dataset.Dataset(data_path, preprocessor, split="train", augment=True)
    valset = dataset.Dataset(data_path, preprocessor, split="validation")
    train_loader = utils.data_loader(trainset, config, world_rank, args.world_size)
    val_loader = utils.data_loader(valset, config, world_rank, args.world_size)
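    # In distributed runs, `utils.data_loader` is presumably responsible for
    # sharding each dataset by rank (it receives `world_rank` and
    # `args.world_size`), so every process trains on a disjoint subset.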

    # setup criterion, model:
    logging.info("Loading model ...")
    criterion, output_size = models.load_criterion(
        config.get("criterion_type", "ctc"),
        preprocessor,
        config.get("criterion", {}),
    )
    criterion = criterion.to(device)
    model = models.load_model(
        config["model_type"], input_size, output_size, config["model"]
    ).to(device)

    if args.restore:
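        # The final positional flag is project-specific; presumably it selects
        # loading the most recent checkpoint rather than the best one.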
        models.load_from_checkpoint(model, criterion, args.checkpoint_path, True)

    n_params = sum(p.numel() for p in model.parameters())
    logging.info(
        "Training {} model with {:,} parameters.".format(config["model_type"], n_params)
    )

    # Store base module, criterion for saving checkpoints
    base_model = model
    base_criterion = criterion  # `decode` cannot be called on DDP module
    if is_distributed_train:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[world_rank]
        )

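        # Criterions with learnable parameters (e.g. ASG-style transition
        # weights) must be synchronized across ranks too; a parameter-free
        # criterion such as CTC needs no DDP wrapper.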
        if len(list(criterion.parameters())) > 0:
            criterion = torch.nn.parallel.DistributedDataParallel(
                criterion, device_ids=[world_rank]
            )

    epochs = config["optim"]["epochs"]
    lr = config["optim"]["learning_rate"]
    step_size = config["optim"]["step_size"]
    max_grad_norm = config["optim"].get("max_grad_norm", None)

    # run training:
    logging.info("Starting training ...")
    # When resuming, reapply the decay the scheduler would have performed over
    # the first `args.last_epoch` epochs. `initial_lr` is set in each param
    # group because StepLR requires it whenever last_epoch != -1.
    scale = 0.5 ** (args.last_epoch // step_size)
    params = [
        {
            "params": model.parameters(),
            "initial_lr": lr * scale,
            "lr": lr * scale,
        }
    ]
    if len(list(criterion.parameters())) > 0:
        crit_lr = config["optim"].get("crit_learning_rate", lr)
        params.append(
            {
                "params": criterion.parameters(),
                "initial_lr": crit_lr * scale,
                "lr": crit_lr * scale,
            }
        )

    optimizer = torch.optim.SGD(params)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=step_size, gamma=0.5, last_epoch=args.last_epoch,
    )
    min_val_loss = float("inf")
    min_val_cer = float("inf")
    min_val_wer = float("inf")

    Timer = utils.CudaTimer if device.type == "cuda" else utils.Timer
    timers = Timer(
        [
            "ds_fetch",  # dataset sample fetch
            "model_fwd",  # model forward
            "crit_fwd",  # criterion forward
            "bwd",  # backward (model + criterion)
            "optim",  # optimizer step
            "metrics",  # viterbi, cer
            "train_total",  # total training
            "test_total",  # total testing
        ]
    )
    num_updates = 0
    for epoch in range(args.last_epoch, epochs):
        logging.info("Epoch {} started.".format(epoch + 1))
        model.train()
        criterion.train()
        start_time = time.time()
        meters = utils.Meters()
        timers.reset()
        timers.start("train_total").start("ds_fetch")
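        # `ds_fetch` brackets the loader's next-batch fetch: started here,
        # stopped on entry to the loop body, and restarted at the end of each
        # iteration below.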
        for inputs, targets in train_loader:
            timers.stop("ds_fetch").start("model_fwd")
            optimizer.zero_grad()
            outputs = model(inputs.to(device))
            timers.stop("model_fwd").start("crit_fwd")
            loss = criterion(outputs, targets)
            timers.stop("crit_fwd").start("bwd")
            loss.backward()
            timers.stop("bwd").start("optim")
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(model.parameters(), criterion.parameters()),
                    max_grad_norm,
                )
            optimizer.step()
            num_updates += 1
            timers.stop("optim").start("metrics")
            meters.loss += loss.item() * len(targets)
            meters.num_samples += len(targets)
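            # Viterbi-decode with the unwrapped criterion (see `base_criterion`
            # above; methods other than `forward` are not exposed through DDP),
            # then accumulate token- and word-level edit distances.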
            tokens_dist, words_dist, n_tokens, n_words = compute_edit_distance(
                base_criterion.viterbi(outputs), targets, preprocessor
            )
            meters.edit_distance_tokens += tokens_dist
            meters.num_tokens += n_tokens
            meters.edit_distance_words += words_dist
            meters.num_words += n_words
            timers.stop("metrics").start("ds_fetch")
        timers.stop("ds_fetch").stop("train_total")
        epoch_time = time.time() - start_time
        if args.world_size > 1:
            meters.sync()
        logging.info(
            "Epoch {} complete. "
            "nUpdates {}, Loss {:.3f}, CER {:.3f}, WER {:.3f},"
            " Time {:.3f} (s), LR {:.3f}".format(
                epoch + 1,
                num_updates,
                meters.avg_loss,
                meters.cer,
                meters.wer,
                epoch_time,
                scheduler.get_last_lr()[0],
            ),
        )
        logging.info("Evaluating validation set...")
        timers.start("test_total")
        val_loss, val_cer, val_wer = test(
            model, base_criterion, val_loader, preprocessor, device, args.world_size
        )
        timers.stop("test_total")
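        # Only rank 0 writes checkpoints; a new best model is flagged before
        # the running minima are updated.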
        if world_rank == 0:
            checkpoint(
                base_model,
                base_criterion,
                args.checkpoint_path,
                (val_cer < min_val_cer),
            )

            min_val_loss = min(val_loss, min_val_loss)
            min_val_cer = min(val_cer, min_val_cer)
            min_val_wer = min(val_wer, min_val_wer)
        logging.info(
            "Validation Set: Loss {:.3f}, CER {:.3f}, WER {:.3f}, "
            "Best Loss {:.3f}, Best CER {:.3f}, Best WER {:.3f}".format(
                val_loss, val_cer, val_wer, min_val_loss, min_val_cer, min_val_wer
            ),
        )
        logging.info(
            "Timing Info: "
            + ", ".join(
                [
                    "{} : {:.2f}ms".format(k, v * 1000.0)
                    for k, v in timers.value().items()
                ]
            )
        )
        scheduler.step()

    if is_distributed_train:
        torch.distributed.destroy_process_group()
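
A minimal sketch of how train() might be launched, assuming one spawned
process per GPU via torch.multiprocessing.spawn. The argument names mirror
the `args` attributes referenced above (config, checkpoint_path, world_size,
dist_backend, dist_url, disable_cuda, restore, last_epoch); the actual entry
point and defaults in train.py may differ.

import argparse

import torch


def main():
    parser = argparse.ArgumentParser(description="Launch training.")
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--checkpoint_path", type=str, default="checkpoints")
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--dist_backend", type=str, default="nccl")
    parser.add_argument("--dist_url", type=str, default="tcp://localhost:23456")
    parser.add_argument("--disable_cuda", action="store_true")
    parser.add_argument("--restore", action="store_true")
    parser.add_argument("--last_epoch", type=int, default=0)
    args = parser.parse_args()

    if args.world_size > 1:
        # spawn calls train(rank, args) once per process, rank in [0, nprocs).
        torch.multiprocessing.spawn(train, args=(args,), nprocs=args.world_size)
    else:
        train(0, args)


if __name__ == "__main__":
    main()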