def _visualize()

in rlkit/core/batch_rl_algorithm.py [0:0]


    def _visualize(self, policy=False, q_function=False, num_dir=50, alpha=0.1, iter=None):
        """Plot Q-values and returns along random directions in policy space.

        The flattened parameters of ``self.trainer.policy`` are perturbed
        symmetrically (``theta + alpha * d`` and ``theta - alpha * d``) along
        ``num_dir`` random Gaussian directions ``d``.  For every perturbed
        policy the method records the value reported by
        ``_eval_q_custom_policy`` for ``qf1`` and ``qf2``, the smaller of the
        two, and the result of ``eval_policy_custom``.  The same four
        quantities are computed for the unperturbed policy and everything is
        handed to ``plot_visualized_data``, one call per quantity.

        Args:
            policy: Flag requesting policy visualization.
            q_function: Flag requesting Q-function visualization.  At least
                one of ``policy`` / ``q_function`` must be truthy; the flags
                are otherwise not consulted and the policy-space probe
                always runs.
            num_dir: Number of random directions to sample.
            alpha: Scale applied to each random direction before it is
                added to / subtracted from the policy weights.
            iter: Identifier appended (via ``str``) to every figure label.

        Raises:
            AssertionError: If both ``policy`` and ``q_function`` are falsy.
        """
        assert policy or q_function, "Both are false, need something to visualize"

        # NOTE(review): assumes get_flat_params returns a 1-D NumPy array so
        # that adding a NumPy direction vector broadcasts element-wise --
        # confirm against its definition.
        policy_weights = get_flat_params(self.trainer.policy)
        policy_dim = policy_weights.shape[0]

        # Perturbed weights are written into a scratch copy so the policy
        # held by the trainer is never modified.
        policy_clone = copy.deepcopy(self.trainer.policy)

        def evaluate(candidate):
            """Return (q1, q2, min(q1, q2), returns) for ``candidate``."""
            q1 = self._eval_q_custom_policy(candidate, self.trainer.qf1)
            q2 = self._eval_q_custom_policy(candidate, self.trainer.qf2)
            return q1, q2, min(q1, q2), self.eval_policy_custom(candidate)

        # Reference values at the unperturbed policy parameters.
        baseline = evaluate(self.trainer.policy)

        plus_samples = []
        minus_samples = []
        for _ in range(num_dir):
            # Step along a freshly sampled direction.  The previous code
            # sampled the direction but then added ``alpha * policy_dim``
            # (a scalar), so every iteration evaluated the same two points.
            random_dir = np.random.normal(size=policy_dim)

            set_flat_params(policy_clone, policy_weights + alpha * random_dir)
            plus_samples.append(evaluate(policy_clone))

            set_flat_params(policy_clone, policy_weights - alpha * random_dir)
            minus_samples.append(evaluate(policy_clone))

        # One figure per recorded quantity; the column order matches the
        # tuple returned by ``evaluate``.
        label_suffix = '_policy_params_iter_' + str(iter)
        for column, label in enumerate(('q1', 'q2', 'qmin', 'returns')):
            plus_values = np.array([sample[column] for sample in plus_samples])
            minus_values = np.array([sample[column] for sample in minus_samples])
            self.plot_visualized_data(
                plus_values,
                minus_values,
                baseline[column],
                fig_label=label + label_suffix,
            )