in scripts/models.py [0:0]
def fit(self, envs, num_iterations, callback=False):
    """Run ``num_iterations`` update steps via ``self.mask_step``.

    On every step, ``self.loss(self.network(x), y)`` is evaluated for each
    ``(x, y)`` pair in ``envs["train"]["envs"]``. The resulting list of
    per-environment losses is passed to ``self.mask_step`` together with
    ``list(self.parameters())`` and the ``tau``, ``wd`` and ``lr`` entries
    of ``self.hparams`` (as keyword arguments). ``mask_step`` performs the
    actual parameter update; its algorithm is not visible here.

    Args:
        envs: Nested mapping. Only ``envs["train"]["envs"]`` is read by this
            method, as an iterable of ``(x, y)`` pairs. The whole mapping is
            forwarded to ``utils.compute_errors`` when ``callback`` is truthy.
        num_iterations: Number of update steps to run. With 0, nothing is
            read or called.
        callback: When truthy, ``utils.compute_errors(self, envs)`` is called
            after every step.

    Returns:
        None.
    """
    for _ in range(num_iterations):
        # Training environments are re-read each step, as in the original.
        env_losses = [
            self.loss(self.network(inputs), targets)
            for inputs, targets in envs["train"]["envs"]
        ]

        # Parameters are materialized before the hyperparameters are read,
        # preserving the original argument evaluation order.
        params = list(self.parameters())
        step_hparams = {key: self.hparams[key] for key in ("tau", "wd", "lr")}
        self.mask_step(env_losses, params, **step_hparams)

        if callback:
            # compute errors
            utils.compute_errors(self, envs)