def irl_training()

in mbirl/experiments/run_model_based_irl.py [0:0]

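irl_training() fits the learnable cost with a bi-level optimization: in the inner loop, `higher` takes a differentiable SGD step on the action sequence of a keypoint MPC wrapper under the current learned cost; in the outer loop, Adam updates the cost parameters so that the resulting rollouts reduce the IRL loss against the expert keypoint demonstrations. Rollouts are plotted every 25 outer iterations, and the cost is evaluated on the held-out test trajectories after every outer iteration.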

def irl_training(learnable_cost, robot_model, irl_loss_fn, train_trajs, test_trajs, n_outer_iter, n_inner_iter,
                 data_type, cost_type, cost_lr=1e-2, action_lr=1e-3):
    irl_loss_on_train = []
    irl_loss_on_test = []

    learnable_cost_opt = torch.optim.Adam(learnable_cost.parameters(), lr=cost_lr)

    irl_loss_dems = []
    # compute the initial IRL loss on the training demos before any cost updates

    plots_dir = os.path.join(model_data_dir, data_type, cost_type)

    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    for demo_i in range(len(train_trajs)):
        expert_demo_dict = train_trajs[demo_i]

        start_pose = expert_demo_dict['start_joint_config'].squeeze()
        expert_demo = expert_demo_dict['desired_keypoints'].reshape(traj_len, -1)
        expert_demo = torch.Tensor(expert_demo)
        time_horizon, n_keypt_dim = expert_demo.shape

        keypoint_mpc_wrapper = GroundTruthKeypointMPCWrapper(robot_model, time_horizon=time_horizon - 1, n_keypt_dim=n_keypt_dim)
        # unroll the policy from the start pose to get the predicted trajectory
        pred_traj = keypoint_mpc_wrapper.roll_out(start_pose.clone())

        # get initial irl loss
        irl_loss = irl_loss_fn(pred_traj, expert_demo).mean()
        irl_loss_dems.append(irl_loss.item())

    irl_loss_on_train.append(torch.Tensor(irl_loss_dems).mean())
    print("irl cost training iter: {} loss: {}".format(0, irl_loss_on_train[-1]))

    print("Cost function parameters to be optimized:")
    for name, param in learnable_cost.named_parameters():
        print(name)
        print(param)

    # start of inverse RL loop
    for outer_i in range(n_outer_iter):
        irl_loss_dems = []

        for demo_i in range(len(train_trajs)):
            learnable_cost_opt.zero_grad()
            expert_demo_dict = train_trajs[demo_i]

            start_pose = expert_demo_dict['start_joint_config'].squeeze()
            expert_demo = expert_demo_dict['desired_keypoints'].reshape(traj_len, -1)
            expert_demo = torch.Tensor(expert_demo)
            time_horizon, n_keypt_dim = expert_demo.shape

            keypoint_mpc_wrapper = GroundTruthKeypointMPCWrapper(robot_model, time_horizon=time_horizon - 1,
                                                                 n_keypt_dim=n_keypt_dim)
            action_optimizer = torch.optim.SGD(keypoint_mpc_wrapper.parameters(), lr=action_lr)

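            # differentiable inner loop: `higher` records the SGD step on the action
            # parameters so the IRL loss can later be backpropagated into the learned cost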
            with higher.innerloop_ctx(keypoint_mpc_wrapper, action_optimizer) as (fpolicy, diffopt):
                pred_traj = fpolicy.roll_out(start_pose.clone())

                # evaluate the learned cost on the rollout against the final expert
                # keypoints (the goal) and take one differentiable action-update step
                learned_cost_val = learnable_cost(pred_traj, expert_demo[-1])
                diffopt.step(learned_cost_val)

                pred_traj = fpolicy.roll_out(start_pose)
                # compute task loss
                irl_loss = irl_loss_fn(pred_traj, expert_demo).mean()
                # backprop the IRL loss to get gradients w.r.t. the learned cost parameters
                irl_loss.backward(retain_graph=True)
                irl_loss_dems.append(irl_loss.detach())

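            # outer-loop update: apply the gradient accumulated through the inner loop to the cost parameters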
            learnable_cost_opt.step()

            if outer_i % 25 == 0:
                plt.figure()
                plt.plot(pred_traj[:, 7].detach(), pred_traj[:, 9].detach(), 'o')
                plt.plot(expert_demo[:, 0], expert_demo[:, 2], 'x')
                plt.title("outer i: {}".format(outer_i))
                plt.savefig(os.path.join(plots_dir, f'{demo_i}_{outer_i}.png'))

        irl_loss_on_train.append(torch.Tensor(irl_loss_dems).mean())
        test_irl_losses = evaluate_action_optimization(learnable_cost.eval(), robot_model, irl_loss_fn, test_trajs,
                                                       n_inner_iter)
        print("irl loss (on train) training iter: {} loss: {}".format(outer_i + 1, irl_loss_on_train[-1]))
        print("irl loss (on test) training iter: {} loss: {}".format(outer_i + 1, test_irl_losses.mean().item()))
        print("")
        irl_loss_on_test.append(test_irl_losses)
        learnable_cost_params = {}
        for name, param in learnable_cost.named_parameters():
            learnable_cost_params[name] = param

        if len(learnable_cost_params) == 0:
            # For RBF Weighted Cost
            for name, param in learnable_cost.weights_fn.named_parameters():
                learnable_cost_params[name] = param

    plt.figure()
    plt.plot(pred_traj[:, 7].detach(), pred_traj[:, 9].detach(), 'o')
    plt.plot(expert_demo[:, 0], expert_demo[:, 2], 'x')
    plt.title("final")
    plt.savefig(os.path.join(plots_dir, f'{demo_i}_final.png'))

    return torch.stack(irl_loss_on_train), torch.stack(irl_loss_on_test), learnable_cost_params, pred_traj
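
A minimal calling sketch, for reference. keypoint_irl_loss, the 7-DoF start configuration, and the assumption that the keypoints occupy the trailing columns of pred_traj are illustrative placeholders, not taken from the repository; learnable_cost, robot_model, and traj_len must come from the surrounding module (run_model_based_irl.py), and data_type/cost_type only name the subdirectory that plots are written to.

import torch

def keypoint_irl_loss(pred_traj, expert_demo):
    # hypothetical per-timestep IRL loss: compares the keypoint dimensions of the
    # rollout (assumed to be the trailing columns of pred_traj) to the expert keypoints
    n_keypt_dim = expert_demo.shape[1]
    return ((pred_traj[:, -n_keypt_dim:] - expert_demo) ** 2).sum(dim=-1)

# each demo is a dict with a start joint configuration and a desired keypoint trajectory
train_trajs = [{
    'start_joint_config': torch.zeros(1, 7),        # placeholder start pose, assumes a 7-DoF arm
    'desired_keypoints': torch.zeros(traj_len, 3),  # placeholder expert keypoints; traj_len is the module-level global used above
}]
test_trajs = train_trajs                            # placeholder; real runs use held-out demos

train_loss, test_loss, cost_params, final_traj = irl_training(
    learnable_cost,                                 # learnable cost nn.Module from the surrounding module
    robot_model,                                    # differentiable robot model consumed by GroundTruthKeypointMPCWrapper
    keypoint_irl_loss,
    train_trajs, test_trajs,
    n_outer_iter=100, n_inner_iter=1,
    data_type='placeholder_data', cost_type='weighted',
    cost_lr=1e-2, action_lr=1e-3)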