def evaluate()

in reagent/evaluation/world_model_evaluator.py


    def evaluate(self, batch: MemoryNetworkInput):
        """Calculate feature importance: setting each state/action feature to
        the mean value and observe loss increase."""

        # put the MDN-RNN world model in eval mode while losses are measured
        self.trainer.memory_network.mdnrnn.eval()
        state_features = batch.state.float_features
        action_features = batch.action
        seq_len, batch_size, state_dim = state_features.size()
        action_dim = action_features.size(2)
        action_feature_num = self.action_feature_num
        state_feature_num = self.state_feature_num
        feature_importance = torch.zeros(action_feature_num + state_feature_num)
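        # scores for action features occupy indices [0, action_feature_num);
        # state feature scores follow at offset action_feature_num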

        orig_losses = self.trainer.get_loss(batch, state_dim=state_dim)
        orig_loss = orig_losses["loss"].cpu().detach().item()
        del orig_losses

        action_feature_boundaries = self.sorted_action_feature_start_indices + [
            action_dim
        ]
        state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]
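        # consecutive entries in each boundary list delimit one feature's column
        # span, so features encoded across several columns are masked as a unit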

        for i in range(action_feature_num):
            action_features = batch.action.reshape(
                (batch_size * seq_len, action_dim)
            ).data.clone()

            # if actions are discrete, an action's feature importance is the loss
            # increase due to setting all actions to this action
            if self.discrete_action:
                assert action_dim == action_feature_num
                action_vec = torch.zeros(action_dim)
                action_vec[i] = 1
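                # e.g., with action_dim == 3 and i == 1, action_vec is [0., 1., 0.]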
                action_features[:] = action_vec
            # if actions are continuous, an action's feature importance is the loss
            # increase due to masking this action feature to its median value
            else:
                boundary_start, boundary_end = (
                    action_feature_boundaries[i],
                    action_feature_boundaries[i + 1],
                )
                median_values = self.compute_median_feature_value(
                    action_features[:, boundary_start:boundary_end]
                )
                action_features[:, boundary_start:boundary_end] = median_values

            action_features = action_features.reshape((seq_len, batch_size, action_dim))
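            # rebuild the batch with only this action feature perturbed; note
            # that time_diff is fixed to all ones in the rebuilt batch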
            new_batch = MemoryNetworkInput(
                state=batch.state,
                action=action_features,
                next_state=batch.next_state,
                reward=batch.reward,
                time_diff=torch.ones_like(batch.reward).float(),
                not_terminal=batch.not_terminal,
                step=None,
            )
            losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
            feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
            del losses

        for i in range(state_feature_num):
            state_features = batch.state.float_features.reshape(
                (batch_size * seq_len, state_dim)
            ).data.clone()
            boundary_start, boundary_end = (
                state_feature_boundaries[i],
                state_feature_boundaries[i + 1],
            )
            median_values = self.compute_median_feature_value(
                state_features[:, boundary_start:boundary_end]
            )
            state_features[:, boundary_start:boundary_end] = median_values
            state_features = state_features.reshape((seq_len, batch_size, state_dim))
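            # rebuild the batch with only this state feature perturbed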
            new_batch = MemoryNetworkInput(
                state=FeatureData(float_features=state_features),
                action=batch.action,
                next_state=batch.next_state,
                reward=batch.reward,
                time_diff=torch.ones_like(batch.reward).float(),
                not_terminal=batch.not_terminal,
                step=None,
            )
            losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
            feature_importance[i + action_feature_num] = (
                losses["loss"].cpu().detach().item() - orig_loss
            )
            del losses

        # restore training mode now that all perturbation losses are collected
        self.trainer.memory_network.mdnrnn.train()
        logger.info(
            "**** Debug tool feature importance ****: {}".format(feature_importance)
        )
        return {"feature_loss_increase": feature_importance.numpy()}
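
compute_median_feature_value() is called above but not shown in this section.
Below is a minimal sketch of the shape contract it has to satisfy, assuming the
helper simply reduces a (rows, feature_width) slice to one median per column
(the real helper may treat categorical or one-hot features differently):

    import torch

    def compute_median_feature_value(features: torch.Tensor) -> torch.Tensor:
        # Tensor.median(dim=0) returns a (values, indices) namedtuple; keeping
        # only the values yields one median per column, and the resulting
        # (feature_width,) tensor broadcasts across every row when assigned
        # back into the boundary_start:boundary_end slice above.
        return features.median(dim=0).values

With this contract, masking replaces only the chosen column span with its
typical value and leaves every other column untouched, so the measured loss
increase is attributable to that feature alone.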