def loss_with_reg()

in utils.py [0:0]


def loss_with_reg(model, data, target, loss_fn, lam):
    """Compute ``loss_fn(model(data), target)`` plus an L2 penalty and backprop it.

    The penalty is ``lam / 2 * sum(||p||^2)`` over every tensor yielded by
    ``model.parameters()`` (biases included). It is only added when
    ``lam > 0``. Existing gradients are cleared with ``model.zero_grad()``
    before the forward pass. After this call, the ``.grad`` of each parameter
    that takes part in the graph holds the gradient of the returned
    (regularized) loss.

    Args:
        model: Module-like object exposing ``zero_grad()``, ``parameters()``
            and ``__call__``.
        data: Input passed to ``model``.
        target: Target passed as the second argument to ``loss_fn``.
        loss_fn: Callable ``(output, target) -> loss tensor``. Presumably
            returns a scalar, so that ``backward()`` needs no gradient
            argument.
        lam: Regularization strength; values ``<= 0`` disable the penalty.

    Returns:
        The regularized loss tensor, still attached to the autograd graph.
    """
    model.zero_grad()
    loss = loss_fn(model(data), target)
    if lam > 0:
        # Add the penalty out of place. ``loss += ...`` would mutate
        # loss_fn's output in place, which makes backward() fail whenever
        # the last op in loss_fn saved its result for the gradient
        # (e.g. sqrt, exp).
        penalty = sum(param.pow(2).sum() for param in model.parameters())
        loss = loss + lam * penalty / 2
    loss.backward()
    return loss