in rlkit/core/batch_rl_algorithm.py [0:0]
def _visualize(self, policy=False, q_function=False, num_dir=50, alpha=0.1, iter=None):
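"""Perturb the current policy parameters along random directions and plot how the
Q-value estimates and evaluation returns change, as a crude probe of the local
loss landscape (e.g. to spot saddle points). Relies on numpy (np), copy, and the
get_flat_params / set_flat_params helpers imported elsewhere in this module."""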
assert policy or q_function, "Need at least one of policy or q_function to visualize"
policy_weights = get_flat_params(self.trainer.policy)
# qf1_weights = get_flat_params(self.trainer.qf1)
# qf2_weights = get_flat_params(self.trainer.qf2)
policy_dim = policy_weights.shape[0]
# qf_dim = qf1_weights.shape[0]
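# NOTE: only the policy-parameter visualization is implemented below; in this
# snippet the q_function flag has no effect beyond the assertion above.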
# Create clones to assign weights
policy_clone = copy.deepcopy(self.trainer.policy)
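# The clone is overwritten in place via set_flat_params on every iteration,
# so the original policy's parameters are never modified.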
# Create arrays for storing data
q1_plus_eval = []
q1_minus_eval = []
q2_plus_eval = []
q2_minus_eval = []
qmin_plus_eval = []
qmin_minus_eval = []
returns_plus_eval = []
returns_minus_eval = []
# Groundtruth policy params
policy_eval_qf1 = self._eval_q_custom_policy(self.trainer.policy, self.trainer.qf1)
policy_eval_qf2 = self._eval_q_custom_policy(self.trainer.policy, self.trainer.qf2)
policy_eval_q_min = min(policy_eval_qf1, policy_eval_qf2)
policy_eval_returns = self.eval_policy_custom(self.trainer.policy)
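# These unperturbed evaluations are passed to plot_visualized_data below as the
# reference value for each plot.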
# Saddle-point detection: evaluate the policy after stepping +/- alpha along num_dir random directions in parameter space
for _ in range(num_dir):
# Sample a random direction and step the flat parameter vector along +/- alpha * direction
random_dir = np.random.normal(size=policy_dim)
theta_plus = policy_weights + alpha * random_dir
theta_minus = policy_weights - alpha * random_dir
set_flat_params(policy_clone, theta_plus)
q_plus_1 = self._eval_q_custom_policy(policy_clone, self.trainer.qf1)
q_plus_2 = self._eval_q_custom_policy(policy_clone, self.trainer.qf2)
q_plus_min = min(q_plus_1, q_plus_2)
eval_return_plus = self.eval_policy_custom(policy_clone)
set_flat_params(policy_clone, theta_minus)
q_minus_1 = self._eval_q_custom_policy(policy_clone, self.trainer.qf1)
q_minus_2 = self._eval_q_custom_policy(policy_clone, self.trainer.qf2)
q_minus_min = min(q_minus_1, q_minus_2)
eval_return_minus = self.eval_policy_custom(policy_clone)
# Append to array
q1_plus_eval.append(q_plus_1)
q2_plus_eval.append(q_plus_2)
q1_minus_eval.append(q_minus_1)
q2_minus_eval.append(q_minus_2)
qmin_plus_eval.append(q_plus_min)
qmin_minus_eval.append(q_minus_min)
returns_plus_eval.append(eval_return_plus)
returns_minus_eval.append(eval_return_minus)
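# After the loop, each list holds num_dir scalar evaluations, one per random direction.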
# Convert the collected evaluations to numpy arrays for plotting
q1_plus_eval = np.array(q1_plus_eval)
q1_minus_eval = np.array(q1_minus_eval)
q2_plus_eval = np.array(q2_plus_eval)
q2_minus_eval = np.array(q2_minus_eval)
qmin_plus_eval = np.array(qmin_plus_eval)
qmin_minus_eval = np.array(qmin_minus_eval)
returns_plus_eval = np.array(returns_plus_eval)
returns_minus_eval = np.array(returns_minus_eval)
self.plot_visualized_data(q1_plus_eval, q1_minus_eval, policy_eval_qf1, fig_label=f'q1_policy_params_iter_{iter}')
self.plot_visualized_data(q2_plus_eval, q2_minus_eval, policy_eval_qf2, fig_label=f'q2_policy_params_iter_{iter}')
self.plot_visualized_data(qmin_plus_eval, qmin_minus_eval, policy_eval_q_min, fig_label=f'qmin_policy_params_iter_{iter}')
self.plot_visualized_data(returns_plus_eval, returns_minus_eval, policy_eval_returns, fig_label=f'returns_policy_params_iter_{iter}')
del policy_clone
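# Example call site (hypothetical): invoked once per epoch from the training loop,
# e.g. self._visualize(policy=True, num_dir=50, alpha=0.1, iter=epoch).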