def meta_train_mountain_car()

in ml3/ml3_train.py


import higher
import torch


def meta_train_mountain_car(policy, ml3_loss, task_loss_fn, s_0, goal, goal_extra,
                            n_outer_iter, n_inner_iter, time_horizon, shaped_loss):
    """Meta-train the learned loss (ml3_loss) so that a policy updated with it
    minimizes the true task loss on the mountain-car task."""
    s_0 = torch.Tensor(s_0)
    goal = torch.Tensor(goal)
    goal_extra = torch.Tensor(goal_extra)

    inner_opt = torch.optim.SGD(policy.parameters(), lr=policy.learning_rate)
    meta_opt = torch.optim.Adam(ml3_loss.parameters(), lr=ml3_loss.learning_rate)

    for outer_i in range(n_outer_iter):
        # set gradient with respect to meta loss parameters to 0
        meta_opt.zero_grad()
        for _ in range(n_inner_iter):
            inner_opt.zero_grad()
            with higher.innerloop_ctx(policy, inner_opt, copy_initial_weights=False) as (fpolicy, diffopt):
                # roll out the current policy and score the trajectory with the learned loss
                s_tr, a_tr, g_tr = fpolicy.roll_out(s_0, goal, time_horizon)

                loss_input = torch.cat([s_tr[:-1], a_tr, g_tr], dim=1)
                pred_task_loss = ml3_loss(loss_input).mean()
                # take a differentiable policy update step under the learned loss
                diffopt.step(pred_task_loss)

            # roll out the updated policy and compute the true task loss
            s, a, g = fpolicy.roll_out(s_0, goal, time_horizon)
            task_loss = task_loss_fn(a, s, goal, goal_extra, shaped_loss)
            # backprop the task loss; gradients reach the meta-loss parameters
            # through the differentiable inner update
            task_loss.backward()

        meta_opt.step()

        if outer_i % 100 == 0:
            print("meta iter: {} loss: {}".format(outer_i, task_loss.item()))
            print('last state', s[-1])
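
For context, here is a minimal usage sketch, not taken from the repo: the Policy and LearnedLoss classes below are hypothetical stand-ins that only provide what meta_train_mountain_car expects (a learning_rate attribute, parameters(), and a differentiable roll_out(s_0, goal, time_horizon) returning states, actions, and goals), and the dynamics inside roll_out are a simplified differentiable approximation of mountain car rather than the repo's actual environment; the task loss and hyperparameters are likewise illustrative assumptions.

import torch
import torch.nn as nn


class Policy(nn.Module):
    # hypothetical goal-conditioned policy with a differentiable rollout
    def __init__(self, state_dim=2, action_dim=1, goal_dim=1, lr=1e-3):
        super().__init__()
        self.learning_rate = lr
        self.net = nn.Sequential(nn.Linear(state_dim + goal_dim, 32),
                                 nn.Tanh(), nn.Linear(32, action_dim))

    def roll_out(self, s_0, goal, time_horizon):
        # unroll simplified, differentiable mountain-car-like dynamics
        states, actions, goals = [s_0], [], []
        s = s_0
        for _ in range(time_horizon):
            a = self.net(torch.cat([s, goal]))
            vel = s[1] + 0.001 * a[0] - 0.0025 * torch.cos(3 * s[0])
            pos = s[0] + vel
            s = torch.stack([pos, vel])
            states.append(s)
            actions.append(a)
            goals.append(goal)
        return torch.stack(states), torch.stack(actions), torch.stack(goals)


class LearnedLoss(nn.Module):
    # hypothetical learned loss: maps (state, action, goal) to a nonnegative scalar
    def __init__(self, in_dim=4, lr=1e-3):
        super().__init__()
        self.learning_rate = lr
        self.net = nn.Sequential(nn.Linear(in_dim, 32), nn.ELU(),
                                 nn.Linear(32, 1), nn.Softplus())

    def forward(self, x):
        return self.net(x)


def task_loss_fn(a, s, goal, goal_extra, shaped_loss):
    # illustrative task loss: squared distance of the final position to the goal,
    # optionally shaped with a target final velocity
    loss = (s[-1, 0] - goal[0]) ** 2
    if shaped_loss:
        loss = loss + (s[-1, 1] - goal_extra[0]) ** 2
    return loss


policy = Policy()
ml3_loss = LearnedLoss()
meta_train_mountain_car(policy, ml3_loss, task_loss_fn,
                        s_0=[-0.5, 0.0], goal=[0.45], goal_extra=[0.0],
                        n_outer_iter=500, n_inner_iter=1,
                        time_horizon=50, shaped_loss=False)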