in code/experiment_synthetic/models.py [0:0]
def train(self, environments, args, reg=0):
    """Fit ``self.phi`` by gradient descent on an error/penalty trade-off.

    ``self.phi`` is (re)initialised to the identity matrix and ``self.w``
    to a column of ones.  Only ``self.phi`` is handed to the optimizer;
    ``self.w`` stays fixed and exists so that the gradient of each
    environment's error with respect to it can be used as a penalty.

    Both tensors are created with the dtype and device of the first
    environment's inputs, so float64 or GPU-resident data can be used
    without a dtype/device mismatch in the matrix products.

    Args:
        environments: Non-empty sequence of ``(x_e, y_e)`` tensor pairs.
            ``x_e`` must have a second dimension (``x_e.size(1)`` is the
            feature count).  ``y_e`` is compared against
            ``x_e @ phi @ w`` with an MSE loss -- presumably shaped
            ``(n, 1)``; confirm against the caller.
        args: Mapping providing ``"lr"`` (Adam learning rate),
            ``"n_iterations"`` (number of optimisation steps) and
            ``"verbose"`` (truthy to log progress every 1000 steps).
        reg: Weight of the summed error in the objective; the penalty
            is weighted by ``1 - reg``.

    Returns:
        None.  The result is stored on ``self.phi`` and ``self.w``.
    """
    first_x = environments[0][0]
    dim_x = first_x.size(1)

    # Parameters must be floating point to carry gradients; for any other
    # input dtype keep the framework default, as a bare torch.eye() would.
    if first_x.is_floating_point():
        dtype = first_x.dtype
    else:
        dtype = torch.get_default_dtype()
    device = first_x.device

    self.phi = torch.nn.Parameter(
        torch.eye(dim_x, dim_x, dtype=dtype, device=device))
    self.w = torch.ones(dim_x, 1, dtype=dtype, device=device,
                        requires_grad=True)

    # w is deliberately left out of the optimizer: it is never updated.
    opt = torch.optim.Adam([self.phi], lr=args["lr"])
    loss = torch.nn.MSELoss()

    for iteration in range(args["n_iterations"]):
        penalty = 0
        error = 0
        for x_e, y_e in environments:
            error_e = loss(x_e @ self.phi @ self.w, y_e)
            # create_graph=True keeps the gradient differentiable so the
            # penalty itself can be back-propagated into phi.
            penalty += grad(error_e, self.w,
                            create_graph=True)[0].pow(2).mean()
            error += error_e

        opt.zero_grad()
        (reg * error + (1 - reg) * penalty).backward()
        opt.step()

        if args["verbose"] and iteration % 1000 == 0:
            w_str = pretty(self.solution())
            print("{:05d} | {:.5f} | {:.5f} | {:.5f} | {}".format(iteration,
                                                                  reg,
                                                                  error,
                                                                  penalty,
                                                                  w_str))