in ml3/ml3_train.py [0:0]
import higher
import torch


def meta_train_mbrl_reacher(policy, ml3_loss, dmodel, env, task_loss_fn, goals,
                            n_outer_iter, n_inner_iter, time_horizon, exp_folder):
    goals = torch.Tensor(goals)
    meta_opt = torch.optim.Adam(ml3_loss.parameters(), lr=ml3_loss.learning_rate)
    for outer_i in range(n_outer_iter):
        # set gradients with respect to the meta loss parameters to 0
        meta_opt.zero_grad()

        all_loss = 0
        for goal in goals:
            goal = torch.Tensor(goal)
            policy.reset()
            inner_opt = torch.optim.SGD(policy.parameters(), lr=policy.learning_rate)

            for _ in range(n_inner_iter):
                inner_opt.zero_grad()
                with higher.innerloop_ctx(policy, inner_opt, copy_initial_weights=False) as (fpolicy, diffopt):
                    # use the current meta loss to take a differentiable update step on the policy
                    s_tr, a_tr, g_tr = fpolicy.roll_out(goal, time_horizon, dmodel, env)
                    meta_input = torch.cat([s_tr[:-1].detach(), a_tr, g_tr.detach()], dim=1)
                    pred_task_loss = ml3_loss(meta_input).mean()
                    diffopt.step(pred_task_loss)

                    # compute the task loss with the updated policy
                    s, a, g = fpolicy.roll_out(goal, time_horizon, dmodel, env)
                    task_loss = task_loss_fn(a, s, goal).mean()
                    # collect losses for logging
                    all_loss += task_loss
            # backprop the task loss; gradients flow through the inner update into the meta loss parameters
            task_loss.backward()
            if outer_i % 100 == 0:
                # roll out in the real environment to monitor training and to collect data for the dynamics model update
                states, actions, _ = fpolicy.roll_out(goal, time_horizon, dmodel, env, real_rollout=True)
                print("meta iter: {} loss: {}".format(outer_i, torch.mean((states[-1, :2] - goal[:2]) ** 2)))

                if outer_i % 300 == 0 and outer_i < 3001:
                    # update the dynamics model with data collected under the current policy
                    dmodel.train(torch.Tensor(states), torch.Tensor(actions))
        # step the optimizer to update the meta loss network
        meta_opt.step()
        torch.save(ml3_loss.state_dict(), f'{exp_folder}/ml3_loss_reacher.pt')
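

# ---------------------------------------------------------------------------
# Standalone sketch (not part of ml3_train.py): a minimal illustration of the
# differentiable inner-step pattern used above, assuming only `torch` and
# `higher`. `toy_policy`, `learned_loss`, and the toy regression data are
# hypothetical stand-ins for the reacher policy, the ML3 loss network, and the
# task: one inner SGD step is taken through the learned loss, then the true
# task loss of the updated policy is backpropagated into the loss network.
# ---------------------------------------------------------------------------
import higher
import torch
import torch.nn as nn

toy_policy = nn.Linear(2, 1)                                    # stand-in for the policy
learned_loss = nn.Sequential(nn.Linear(3, 32), nn.ReLU(),       # stand-in for ml3_loss
                             nn.Linear(32, 1))
meta_opt = torch.optim.Adam(learned_loss.parameters(), lr=1e-3)
inner_opt = torch.optim.SGD(toy_policy.parameters(), lr=1e-2)

x = torch.randn(16, 2)                                          # toy "states"
y = x.sum(dim=1, keepdim=True)                                  # toy task targets

for outer_i in range(5):
    meta_opt.zero_grad()
    with higher.innerloop_ctx(toy_policy, inner_opt, copy_initial_weights=False) as (fpolicy, diffopt):
        # inner step: update the functional policy copy using the *learned* loss
        pred = fpolicy(x)
        inner_loss = learned_loss(torch.cat([x, pred], dim=1)).mean()
        diffopt.step(inner_loss)

        # outer step: evaluate the updated policy with the true task loss and
        # backprop through the inner update into the learned loss parameters
        task_loss = ((fpolicy(x) - y) ** 2).mean()
        task_loss.backward()
    meta_opt.step()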