in mbirl/experiments/run_model_based_irl.py
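# The function below assumes these module-level imports (a sketch of what the
# surrounding script provides; `GroundTruthKeypointMPCWrapper`,
# `evaluate_action_optimization`, `model_data_dir`, and `traj_len` are defined
# or imported elsewhere in this file):
import os
import torch
import higher                    # differentiable inner-loop optimization
import matplotlib.pyplot as plt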
def irl_training(learnable_cost, robot_model, irl_loss_fn, train_trajs, test_trajs, n_outer_iter, n_inner_iter,
                 data_type, cost_type, cost_lr=1e-2, action_lr=1e-3):
    irl_loss_on_train = []
    irl_loss_on_test = []

    learnable_cost_opt = torch.optim.Adam(learnable_cost.parameters(), lr=cost_lr)
    irl_loss_dems = []

    # initial loss before training
    plots_dir = os.path.join(model_data_dir, data_type, cost_type)
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    for demo_i in range(len(train_trajs)):
        expert_demo_dict = train_trajs[demo_i]

        start_pose = expert_demo_dict['start_joint_config'].squeeze()
        expert_demo = expert_demo_dict['desired_keypoints'].reshape(traj_len, -1)
        expert_demo = torch.Tensor(expert_demo)
        time_horizon, n_keypt_dim = expert_demo.shape

        keypoint_mpc_wrapper = GroundTruthKeypointMPCWrapper(robot_model, time_horizon=time_horizon - 1,
                                                             n_keypt_dim=n_keypt_dim)
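        # Note: the wrapper's learnable parameters are the action sequence it rolls
        # out through robot_model (that is what the SGD "action_optimizer" further
        # below updates); roll_out returns the predicted keypoint trajectory.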
        # unroll and extract expected features
        pred_traj = keypoint_mpc_wrapper.roll_out(start_pose.clone())

        # get initial irl loss
        irl_loss = irl_loss_fn(pred_traj, expert_demo).mean()
        irl_loss_dems.append(irl_loss.item())

    irl_loss_on_train.append(torch.Tensor(irl_loss_dems).mean())
    print("irl cost training iter: {} loss: {}".format(0, irl_loss_on_train[-1]))

    print("Cost function parameters to be optimized:")
    for name, param in learnable_cost.named_parameters():
        print(name)
        print(param)
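    # Bi-level structure of the loop below: the inner step optimizes the action
    # sequence under the current learnable cost (via `higher`, so the update stays
    # differentiable), and the outer step updates the cost parameters so that the
    # resulting trajectory minimizes the IRL loss against the expert demonstration.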
    # start of inverse RL loop
    for outer_i in range(n_outer_iter):
        irl_loss_dems = []

        for demo_i in range(len(train_trajs)):
            learnable_cost_opt.zero_grad()
            expert_demo_dict = train_trajs[demo_i]

            start_pose = expert_demo_dict['start_joint_config'].squeeze()
            expert_demo = expert_demo_dict['desired_keypoints'].reshape(traj_len, -1)
            expert_demo = torch.Tensor(expert_demo)
            time_horizon, n_keypt_dim = expert_demo.shape

            keypoint_mpc_wrapper = GroundTruthKeypointMPCWrapper(robot_model, time_horizon=time_horizon - 1,
                                                                 n_keypt_dim=n_keypt_dim)
            action_optimizer = torch.optim.SGD(keypoint_mpc_wrapper.parameters(), lr=action_lr)
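            # higher.innerloop_ctx yields a functional copy of the policy (fpolicy)
            # and a differentiable version of action_optimizer (diffopt), so the
            # action update below remains on the autograd graph of the learned cost.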
            with higher.innerloop_ctx(keypoint_mpc_wrapper, action_optimizer) as (fpolicy, diffopt):
                pred_traj = fpolicy.roll_out(start_pose.clone())

                # use the learned loss to update the action sequence
                learned_cost_val = learnable_cost(pred_traj, expert_demo[-1])
                diffopt.step(learned_cost_val)
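                # After this single differentiable action update, the re-rolled
                # trajectory below is a function of the learnable cost parameters,
                # which is what lets irl_loss.backward() reach them.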
                pred_traj = fpolicy.roll_out(start_pose)

                # compute task loss
                irl_loss = irl_loss_fn(pred_traj, expert_demo).mean()
                # backprop the gradient of the irl loss wrt the learned cost parameters
                irl_loss.backward(retain_graph=True)
                irl_loss_dems.append(irl_loss.detach())

            learnable_cost_opt.step()
            if outer_i % 25 == 0:
                plt.figure()
                plt.plot(pred_traj[:, 7].detach(), pred_traj[:, 9].detach(), 'o')
                plt.plot(expert_demo[:, 0], expert_demo[:, 2], 'x')
                plt.title("outer i: {}".format(outer_i))
                plt.savefig(os.path.join(plots_dir, f'{demo_i}_{outer_i}.png'))
                plt.close()  # avoid accumulating open figures across iterations
        irl_loss_on_train.append(torch.Tensor(irl_loss_dems).mean())

        test_irl_losses = evaluate_action_optimization(learnable_cost.eval(), robot_model, irl_loss_fn, test_trajs,
                                                       n_inner_iter)

        print("irl loss (on train) training iter: {} loss: {}".format(outer_i + 1, irl_loss_on_train[-1]))
        print("irl loss (on test) training iter: {} loss: {}".format(outer_i + 1, test_irl_losses.mean().item()))
        print("")
        irl_loss_on_test.append(test_irl_losses)
    learnable_cost_params = {}
    for name, param in learnable_cost.named_parameters():
        learnable_cost_params[name] = param

    if len(learnable_cost_params) == 0:
        # for the RBF weighted cost, the parameters live in its weights_fn submodule
        for name, param in learnable_cost.weights_fn.named_parameters():
            learnable_cost_params[name] = param
    plt.figure()
    plt.plot(pred_traj[:, 7].detach(), pred_traj[:, 9].detach(), 'o')
    plt.plot(expert_demo[:, 0], expert_demo[:, 2], 'x')
    plt.title("final")
    plt.savefig(os.path.join(plots_dir, f'{demo_i}_final.png'))
    plt.close()

    return torch.stack(irl_loss_on_train), torch.stack(irl_loss_on_test), learnable_cost_params, pred_traj
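
# Minimal usage sketch (illustrative only: the cost class, loss, data loading,
# and hyperparameter values below are placeholders, not this script's exact
# __main__ entry point):
#
#   cost = LearnableWeightedCost()            # hypothetical learnable cost module
#   robot = DifferentiableRobotModel(...)     # hypothetical differentiable kinematics
#   train_losses, test_losses, cost_params, final_traj = irl_training(
#       cost, robot, irl_loss_fn=IRLLoss(),   # IRLLoss: hypothetical keypoint-space loss
#       train_trajs=train_trajs, test_trajs=test_trajs,
#       n_outer_iter=100, n_inner_iter=1,
#       data_type='placing', cost_type='weighted')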