in mbirl/experiments/run_model_based_irl.py [0:0]
def evaluate_action_optimization(learned_cost, robot_model, irl_loss_fn, trajs, n_inner_iter, action_lr=0.001):
    """Optimize an action sequence under the learned cost for each trajectory,
    then score the resulting rollout against the expert demo with the IRL loss."""
    # np.random.seed(cfg.random_seed)
    # torch.manual_seed(cfg.random_seed)
    eval_costs = []
    for traj in trajs:
        traj_len = len(traj['desired_keypoints'])
        start_pose = traj['start_joint_config'].squeeze()
        expert_demo = traj['desired_keypoints'].reshape(traj_len, -1)
        expert_demo = torch.Tensor(expert_demo)
        time_horizon, n_keypt_dim = expert_demo.shape

        keypoint_mpc_wrapper = GroundTruthKeypointMPCWrapper(
            robot_model, time_horizon=time_horizon - 1, n_keypt_dim=n_keypt_dim)
        action_optimizer = torch.optim.SGD(keypoint_mpc_wrapper.parameters(), lr=action_lr)

        # Inner loop: use the learned cost to update the action sequence.
        for _ in range(n_inner_iter):
            action_optimizer.zero_grad()
            pred_traj = keypoint_mpc_wrapper.roll_out(start_pose.clone())
            # Learned cost is evaluated against the goal keypoints (last demo frame).
            learned_cost_val = learned_cost(pred_traj, expert_demo[-1])
            learned_cost_val.backward(retain_graph=True)
            action_optimizer.step()

        # Roll out once more with the optimized actions and score against the full demo.
        pred_state_traj_new = keypoint_mpc_wrapper.roll_out(start_pose.clone())
        eval_costs.append(irl_loss_fn(pred_state_traj_new, expert_demo).mean())

    return torch.stack(eval_costs).detach()
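# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the repo): shows the interfaces
# evaluate_action_optimization expects from its arguments, as inferred from the
# body above. WeightedKeypointCost, keypoint_l2_loss, MyRobotModel, and
# 'eval_trajs.pt' are hypothetical stand-ins; the real learned cost, robot
# model, and data loading live elsewhere in mbirl.
import torch


class WeightedKeypointCost(torch.nn.Module):
    """Toy learned cost: weighted squared distance of each predicted keypoint
    frame to the goal keypoints (the last frame of the expert demo)."""

    def __init__(self, n_keypt_dim):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.ones(n_keypt_dim))

    def forward(self, pred_traj, goal_keypts):
        return (self.weights * (pred_traj - goal_keypts) ** 2).sum()


def keypoint_l2_loss(pred_traj, expert_demo):
    # Per-timestep squared keypoint error between the rollout and the demo.
    return ((pred_traj - expert_demo) ** 2).sum(dim=-1)


# Each traj is a dict with 'start_joint_config' (initial joint angles) and
# 'desired_keypoints' (T x n_keypoints x 3 target keypoint positions).
trajs = torch.load('eval_trajs.pt')        # hypothetical data file
robot_model = MyRobotModel()               # hypothetical differentiable kinematics model
learned_cost = WeightedKeypointCost(n_keypt_dim=9)

eval_costs = evaluate_action_optimization(
    learned_cost, robot_model, keypoint_l2_loss, trajs,
    n_inner_iter=50, action_lr=0.001)
print('mean eval IRL loss:', eval_costs.mean().item())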