in covid19_spread/bar.py [0:0]
def train(model, new_cases, regions, optimizer, checkpoint, args):
    """Fit ``model`` to ``new_cases`` by minimising the negative log-likelihood.

    Every iteration scores the full series with ``model.score``, builds the
    predictive distribution with ``model.dist``, and minimises the negative
    log-probability of the observations shifted ``days_ahead`` steps along
    dim 1.  An optional temporal-smoothness penalty on ``beta`` is added to
    the loss, and an optional Granger regulariser is applied to
    ``model._alphas`` as a manual update outside of the optimizer.

    Every 500 iterations a progress line is printed and
    ``model.state_dict()`` is saved to ``checkpoint``.

    Args:
        model: object providing ``score(t, new_cases)`` (returning
            ``(scores, beta, W)``), ``dist(scores)``, ``parameters()``,
            ``state_dict()`` and the attributes ``_alphas``, ``z``, ``nu``
            and ``scale``.
        new_cases: tensor indexed along dim 1 by time step; dim 0 is
            presumably one row per region (``beta.size(0)`` is asserted to
            equal ``len(regions)``) -- confirm against the caller.
        regions: sized collection; only its length is used.
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        checkpoint: destination passed to ``th.save``.
        args: namespace read for ``niters``, ``temporal``, ``granger``,
            ``lr``, ``eta``, ``loss`` and, optionally, ``days_ahead``
            (defaults to 1).

    Returns:
        The same ``model`` object, trained in place.

    Raises:
        AssertionError: if the scores have an unexpected shape, ``beta``
            does not have ``len(regions)`` rows, or the loss is NaN.
    """
    print(args)
    days_ahead = getattr(args, "days_ahead", 1)
    M = len(regions)
    device = new_cases.device
    tmax = new_cases.size(1)
    t = th.arange(tmax, device=device) + 1
    size_pred = tmax - days_ahead
    # Stays at zero unless the temporal penalty is enabled below.
    reg = th.tensor([0.0], device=device)
    target = new_cases.narrow(1, days_ahead, size_pred)
    # Only filled in by the periodic logging pass; kept as None so the final
    # report can tell whether such a pass ever ran.
    maes = None
    start_time = timeit.default_timer()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        scores, beta, W = model.score(t, new_cases)
        scores = scores.clamp(min=1e-8)
        assert scores.dim() == 2, scores.size()
        assert scores.size(1) == size_pred + 1
        assert beta.size(0) == M

        # compute loss
        dist = model.dist(scores.narrow(1, days_ahead - 1, size_pred))
        _loss = dist.log_prob(target)
        loss = -_loss.sum(axis=1).mean()
        stddev = model.dist(scores).stddev.mean()

        # temporal smoothness: penalise squared first differences of beta
        if args.temporal > 0:
            reg = (
                args.temporal
                * th.pow(beta[:, 1:] - beta[:, :-1], 2).sum(axis=1).mean()
            )

        # back prop
        (loss + reg).backward()

        # do AdamW-like update for Granger regularization: the step below is
        # applied directly to the parameter instead of going through autograd
        # and the optimizer.
        if args.granger > 0:
            with th.no_grad():
                mu = np.log(args.granger / (1 - args.granger))
                y = args.granger
                n = th.numel(model._alphas)
                # exp(-alpha) is taken *before* the diagonal is overwritten.
                ex = th.exp(-model._alphas)
                # Pin the diagonal to logit(granger), so sigmoid(diag) == y.
                model._alphas.fill_diagonal_(mu)
                # de / nu is the gradient of (mean(sigmoid(alpha)) - y) ** 2
                # with respect to each entry of alpha.
                de = 2 * (model._alphas.sigmoid().mean() - y) * ex
                nu = n * (ex + 1) ** 2
                _grad = de / nu
                # The diagonal is fixed above, so it receives no update.
                _grad.fill_diagonal_(0)
                r = args.lr * args.eta * n
                model._alphas.copy_(model._alphas - r * _grad)

        # make sure we have no NaNs (NaN is the only value unequal to itself)
        assert loss == loss, (loss, scores, _loss)
        nn.utils.clip_grad_norm_(model.parameters(), 5)

        # take gradient step
        optimizer.step()

        # control
        if itr % 500 == 0:
            elapsed = timeit.default_timer() - start_time
            with th.no_grad(), np.printoptions(precision=3, suppress=True):
                length = scores.size(1) - 1
                # NOTE(review): this compares against new_cases shifted by one
                # step, which equals ``target`` only when days_ahead == 1.
                maes = th.abs(dist.mean - new_cases.narrow(1, 1, length))
                z = model.z
                nu_sig = th.sigmoid(model.nu)
                means = model.dist(scores).mean
                W_spread = (W * (1 - W)).mean()
                _err = W.mean() - args.granger
                print(
                    f"[{itr:04d}] Loss {loss.item():.2f} | "
                    f"Temporal {reg.item():.5f} | "
                    f"MAE {maes.mean():.2f} | "
                    f"{model} | "
                    f"{args.loss} ({means[:, -1].min().item():.2f}, {means[:, -1].max().item():.2f}) | "
                    f"z ({z.min().item():.2f}, {z.mean().item():.2f}, {z.max().item():.2f}) | "
                    f"W ({W.min().item():.2f}, {W.mean().item():.2f}, {W.max().item():.2f}) | "
                    f"W_spread {W_spread:.2f} | mu_err {_err:.3f} | "
                    f"nu ({nu_sig.min().item():.2f}, {nu_sig.mean().item():.2f}, {nu_sig.max().item():.2f}) | "
                    f"nb_stddev ({stddev.data.mean().item():.2f}) | "
                    f"scale ({th.exp(model.scale).mean():.2f}) | "
                    f"time = {elapsed:.2f}s"
                )
            th.save(model.state_dict(), checkpoint)
            start_time = timeit.default_timer()

    if maes is None:
        # Fewer than 500 iterations ran, so no logging pass computed the MAE.
        # Evaluate it once here so the final report does not fail on an
        # undefined name.
        with th.no_grad():
            scores, _, _ = model.score(t, new_cases)
            scores = scores.clamp(min=1e-8)
            dist = model.dist(scores.narrow(1, days_ahead - 1, size_pred))
            length = scores.size(1) - 1
            maes = th.abs(dist.mean - new_cases.narrow(1, 1, length))
    print(f"Train MAE,{maes.mean():.2f}")
    return model