def train()

in covid19_spread/bar.py [0:0]
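
Full-batch training loop: each iteration scores the case history, minimizes the
negative log-likelihood of cases `days_ahead` days out (optionally with a
temporal-smoothness penalty on beta), applies a decoupled closed-form update for
the Granger sparsity penalty on `model._alphas`, clips gradients, and steps the
optimizer, logging metrics and checkpointing every 500 iterations.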


# imports implied by the body of this excerpt
import timeit

import numpy as np
import torch as th
from torch import nn


def train(model, new_cases, regions, optimizer, checkpoint, args):
    print(args)
    days_ahead = getattr(args, "days_ahead", 1)
    M = len(regions)
    device = new_cases.device
    tmax = new_cases.size(1)
    t = th.arange(tmax, device=device) + 1
    size_pred = tmax - days_ahead
    reg = th.tensor([0.0], device=device)
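    # shift targets by the forecast horizon: target[:, i] = new_cases[:, i + days_ahead]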
    target = new_cases.narrow(1, days_ahead, size_pred)

    start_time = timeit.default_timer()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
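        # scores: per-region/per-day predictions; beta: per-region rates over time; W: cross-region coupling weights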
        scores, beta, W = model.score(t, new_cases)
        scores = scores.clamp(min=1e-8)
        assert scores.dim() == 2, scores.size()
        assert scores.size(1) == size_pred + 1
        assert beta.size(0) == M

        # compute loss
        dist = model.dist(scores.narrow(1, days_ahead - 1, size_pred))
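        # negative log-likelihood of the shifted targets, summed over time and averaged over regions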
        _loss = dist.log_prob(target)
        loss = -_loss.sum(axis=1).mean()

        stddev = model.dist(scores).stddev.mean()
        # loss += stddev * args.weight_decay

        # temporal smoothness
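        # penalize squared day-to-day changes in beta to keep trajectories smooth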
        if args.temporal > 0:
            reg = (
                args.temporal * th.pow(beta[:, 1:] - beta[:, :-1], 2).sum(axis=1).mean()
            )

        # back prop
        (loss + reg).backward()

        # decoupled (AdamW-style) update for the Granger regularizer: the
        # penalty (mean(sigmoid(_alphas)) - args.granger)^2 is applied in
        # closed form outside autograd so the diagonal can be pinned first
        if args.granger > 0:
            with th.no_grad():
                mu = np.log(args.granger / (1 - args.granger))  # logit of the target mean
                y = args.granger
                n = th.numel(model._alphas)
                ex = th.exp(-model._alphas)
                model._alphas.fill_diagonal_(mu)  # pin self-edges to the target logit
                # d penalty / d alpha = 2 * (mean(sigmoid(alpha)) - y) * sigmoid'(alpha) / n,
                # with sigmoid'(alpha) = ex / (1 + ex)^2
                de = 2 * (model._alphas.sigmoid().mean() - y) * ex
                nu = n * (ex + 1) ** 2
                _grad = de / nu
                _grad.fill_diagonal_(0)  # the pinned diagonal takes no gradient step
                r = args.lr * args.eta * n
                model._alphas.copy_(model._alphas - r * _grad)

        # guard against NaNs (NaN is the only value not equal to itself)
        assert loss == loss, (loss, scores, _loss)

        # clip gradients, then take the optimizer step
        nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        # periodic logging and checkpointing
        if itr % 500 == 0:
            time = timeit.default_timer() - start_time
            with th.no_grad(), np.printoptions(precision=3, suppress=True):
                # MAE of the predictive mean against the shifted targets
                maes = th.abs(dist.mean - target)
                z = model.z
                nu = th.sigmoid(model.nu)
                means = model.dist(scores).mean
                W_spread = (W * (1 - W)).mean()
                _err = W.mean() - args.granger
                print(
                    f"[{itr:04d}] Loss {loss.item():.2f} | "
                    f"Temporal {reg.item():.5f} | "
                    f"MAE {maes.mean():.2f} | "
                    f"{model} | "
                    f"{args.loss} ({means[:, -1].min().item():.2f}, {means[:, -1].max().item():.2f}) | "
                    f"z ({z.min().item():.2f}, {z.mean().item():.2f}, {z.max().item():.2f}) | "
                    f"W ({W.min().item():.2f}, {W.mean().item():.2f}, {W.max().item():.2f}) | "
                    f"W_spread {W_spread:.2f} | mu_err {_err:.3f} | "
                    f"nu ({nu.min().item():.2f}, {nu.mean().item():.2f}, {nu.max().item():.2f}) | "
                    f"nb_stddev ({stddev.data.mean().item():.2f}) | "
                    f"scale ({th.exp(model.scale).mean():.2f}) | "
                    f"time = {time:.2f}s"
                )
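                # checkpoint the latest weights each time we log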
                th.save(model.state_dict(), checkpoint)
                start_time = timeit.default_timer()
    print(f"Train MAE,{maes.mean():.2f}")
    return model
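

A minimal sketch (hypothetical, not part of bar.py) that checks the closed-form
Granger gradient used above against autograd. `alphas` and `granger` stand in
for `model._alphas` and `args.granger`; the diagonal pinning from train() is
omitted since those entries take no gradient step anyway.

import torch as th

th.manual_seed(0)
M = 5
granger = 0.1
alphas = th.randn(M, M, requires_grad=True)

# autograd gradient of the penalty (mean(sigmoid(alphas)) - granger)^2
penalty = (alphas.sigmoid().mean() - granger) ** 2
penalty.backward()

# closed-form gradient as computed in train(): de / nu
with th.no_grad():
    n = alphas.numel()
    ex = th.exp(-alphas)
    de = 2 * (alphas.sigmoid().mean() - granger) * ex
    nu = n * (ex + 1) ** 2
    manual = de / nu

print(th.allclose(alphas.grad, manual))  # expected: True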