# ml3/ml3_train.py
import torch
import higher


def meta_train_mountain_car(policy, ml3_loss, task_loss_fn, s_0, goal, goal_extra,
                            n_outer_iter, n_inner_iter, time_horizon, shaped_loss):
    """Meta-train the learned loss (ML3) on the mountain-car task.

    Inner loop: update a differentiable copy of the policy by minimizing the
    learned (meta) loss. Outer loop: update the learned-loss parameters so
    that the policy produced by the inner update does well on the true task loss.
    """
    s_0 = torch.Tensor(s_0)
    goal = torch.Tensor(goal)
    goal_extra = torch.Tensor(goal_extra)

    # inner optimizer updates the policy, meta optimizer updates the learned loss
    inner_opt = torch.optim.SGD(policy.parameters(), lr=policy.learning_rate)
    meta_opt = torch.optim.Adam(ml3_loss.parameters(), lr=ml3_loss.learning_rate)

    for outer_i in range(n_outer_iter):
        # set gradients with respect to the meta-loss parameters to 0
        meta_opt.zero_grad()

        for _ in range(n_inner_iter):
            inner_opt.zero_grad()
            with higher.innerloop_ctx(policy, inner_opt, copy_initial_weights=False) as (fpolicy, diffopt):
                # use the current meta loss to update the policy (differentiable step)
                s_tr, a_tr, g_tr = fpolicy.roll_out(s_0, goal, time_horizon)
                loss_input = torch.cat([s_tr[:-1], a_tr, g_tr], dim=1)
                pred_task_loss = ml3_loss(loss_input).mean()
                diffopt.step(pred_task_loss)

                # evaluate the updated policy under the true task loss
                s, a, g = fpolicy.roll_out(s_0, goal, time_horizon)
                task_loss = task_loss_fn(a, s, goal, goal_extra, shaped_loss)

                # backprop the task loss through the inner update into the meta-loss parameters
                task_loss.backward()

        meta_opt.step()

        if outer_i % 100 == 0:
            print("meta iter: {} loss: {}".format(outer_i, task_loss.item()))
            print('last state', s[-1])
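

# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file).
# The real policy, learned-loss network, and task loss live elsewhere in
# the ML3 repo; the toy stand-ins below only assume the interface this
# trainer relies on: a `learning_rate` attribute, `parameters()`, and a
# differentiable `roll_out(s_0, goal, time_horizon)` that returns
# (states, actions, goals) tensors.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class ToyPolicy(nn.Module):
        """Minimal stand-in policy with placeholder differentiable dynamics."""

        def __init__(self, state_dim=2, action_dim=1, lr=1e-2):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(2 * state_dim, 32), nn.Tanh(), nn.Linear(32, action_dim)
            )
            self.learning_rate = lr

        def forward(self, state, goal):
            return self.net(torch.cat([state, goal], dim=-1))

        def roll_out(self, s_0, goal, time_horizon):
            # Placeholder dynamics: the action nudges the first state dimension.
            states, actions, goals = [s_0], [], []
            s = s_0
            for _ in range(time_horizon):
                a = self.forward(s, goal)
                s = s + torch.cat([a, torch.zeros_like(a)], dim=-1)
                states.append(s)
                actions.append(a)
                goals.append(goal)
            return torch.stack(states), torch.stack(actions), torch.stack(goals)

    class ToyLearnedLoss(nn.Module):
        """Small MLP mapping (state, action, goal) tuples to a scalar loss."""

        def __init__(self, input_dim=5, lr=1e-3):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(input_dim, 32), nn.ReLU(), nn.Linear(32, 1))
            self.learning_rate = lr

        def forward(self, x):
            return self.net(x)

    def toy_task_loss(actions, states, goal, goal_extra, shaped_loss):
        # Hypothetical task loss: distance of the trajectory (or final state)
        # to the goal; `goal_extra` is unused in this toy version.
        if shaped_loss:
            return ((states - goal) ** 2).mean()
        return ((states[-1] - goal) ** 2).mean()

    meta_train_mountain_car(
        ToyPolicy(), ToyLearnedLoss(), toy_task_loss,
        s_0=[-0.5, 0.0], goal=[0.45, 0.0], goal_extra=[0.0, 0.0],
        n_outer_iter=200, n_inner_iter=1, time_horizon=20, shaped_loss=False,
    )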