def meta_train_mbrl_reacher()

in ml3/ml3_train.py


import higher
import torch


def meta_train_mbrl_reacher(policy, ml3_loss, dmodel, env, task_loss_fn, goals, n_outer_iter, n_inner_iter, time_horizon, exp_folder):
    goals = torch.Tensor(goals)

    meta_opt = torch.optim.Adam(ml3_loss.parameters(), lr=ml3_loss.learning_rate)

    for outer_i in range(n_outer_iter):
        # set gradient with respect to meta loss parameters to 0
        meta_opt.zero_grad()
        all_loss = 0
        for goal in goals:
            goal = torch.Tensor(goal)
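            # reset the policy and build a fresh inner-loop optimizer for this goal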
            policy.reset()
            inner_opt = torch.optim.SGD(policy.parameters(), lr=policy.learning_rate)
            for _ in range(n_inner_iter):
                inner_opt.zero_grad()
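                # higher.innerloop_ctx yields a functional copy of the policy and a
                # differentiable optimizer, so the task loss computed after the inner
                # update can be backpropagated into the meta-loss network parameters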
                with higher.innerloop_ctx(policy, inner_opt, copy_initial_weights=False) as (fpolicy, diffopt):
                    # use the current learned (meta) loss to update the policy
                    s_tr, a_tr, g_tr = fpolicy.roll_out(goal, time_horizon, dmodel, env)
                    meta_input = torch.cat([s_tr[:-1].detach(), a_tr, g_tr.detach()], dim=1)
                    pred_task_loss = ml3_loss(meta_input).mean()
                    diffopt.step(pred_task_loss)
                    # roll out the updated policy and compute the true task loss
                    s, a, g = fpolicy.roll_out(goal, time_horizon, dmodel, env)
                    task_loss = task_loss_fn(a, s, goal).mean()

            # collect losses for logging
            all_loss += task_loss
            # backpropagate the task loss through the differentiable inner update
            # to accumulate gradients on the meta-loss parameters
            task_loss.backward()

            if outer_i % 100 == 0:
                # roll out in the real environment to monitor training and to collect data for the dynamics model update
                states, actions, _ = fpolicy.roll_out(goal, time_horizon, dmodel, env, real_rollout=True)
                print("meta iter: {} loss: {}".format(outer_i, (torch.mean((states[-1,:2]-goal[:2])**2))))
                if outer_i % 300 == 0 and outer_i < 3001:
                    # update dynamics model under current optimal policy
                    dmodel.train(torch.Tensor(states), torch.Tensor(actions))

        # step optimizer to update meta loss network
        meta_opt.step()
        torch.save(ml3_loss.state_dict(), f'{exp_folder}/ml3_loss_reacher.pt')
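
Below is a minimal usage sketch. The Policy, ML3Loss, DynamicsModel, and make_reacher_env names are hypothetical placeholders for the repository's actual policy, learned-loss, dynamics-model, and environment constructors, and reacher_task_loss is an assumed task loss that mirrors the squared-distance monitoring metric printed during training; only the call signature of meta_train_mbrl_reacher is taken from the code above.

import numpy as np


def reacher_task_loss(actions, states, goal):
    # assumed task loss: squared distance of each visited state's (x, y) position
    # to the goal position, matching the monitoring metric printed during training
    return ((states[:, :2] - goal[:2]) ** 2).sum(dim=1)


# hypothetical setup; Policy, ML3Loss, DynamicsModel and make_reacher_env stand in
# for the repository's actual policy, learned-loss, dynamics-model and env builders
# policy = Policy(state_dim=..., action_dim=..., learning_rate=1e-3)
# ml3_loss = ML3Loss(meta_input_dim=..., learning_rate=1e-3)
# dmodel = DynamicsModel(...)
# env = make_reacher_env()

goals = np.array([[0.10, 0.10], [-0.10, 0.15]])

# meta_train_mbrl_reacher(policy, ml3_loss, dmodel, env, reacher_task_loss, goals,
#                         n_outer_iter=5000, n_inner_iter=1, time_horizon=50,
#                         exp_folder='experiments/reacher')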