in scripts/models.py [0:0]
def fit(self, envs, num_iterations, callback=False):
    """Train ``self.network`` full-batch on the pooled training environments.

    The ``(x, y)`` pairs listed under ``envs["train"]["envs"]`` are each
    concatenated with ``torch.cat`` into one input tensor and one target
    tensor. The optimizer then takes ``num_iterations`` steps, each on the
    loss over that whole pooled batch.

    Args:
        envs: Mapping whose ``envs["train"]["envs"]`` entry is an iterable of
            ``(x, y)`` tensor pairs. The whole mapping is also passed to
            ``utils.compute_errors`` when ``callback`` is truthy.
        num_iterations: Number of optimizer steps to take.
        callback: If truthy, ``utils.compute_errors(self, envs)`` is invoked
            after each step. Its return value is discarded.

    Returns:
        None. Only ``self.optimizer`` is stepped; presumably it was built
        over ``self.network``'s parameters (not visible here).
    """
    train_pairs = envs["train"]["envs"]
    # Environment boundaries are discarded: all pairs become one batch.
    inputs = torch.cat([features for features, _ in train_pairs])
    targets = torch.cat([labels for _, labels in train_pairs])

    for _ in range(num_iterations):
        self.optimizer.zero_grad()
        predictions = self.network(inputs)
        objective = self.loss(predictions, targets)
        objective.backward()
        self.optimizer.step()

        # NOTE(review): the source's indentation was lost; running the callback
        # after every step is a reconstruction -- confirm it was not meant to
        # run only once after training.
        if callback:
            utils.compute_errors(self, envs)