in reagent/evaluation/world_model_evaluator.py [0:0]
def evaluate(self, batch: MemoryNetworkInput):
"""Calculate feature importance: setting each state/action feature to
the mean value and observe loss increase."""
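# Put the world model in eval mode while measuring; train mode is restored
# at the end of this method.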
self.trainer.memory_network.mdnrnn.eval()
state_features = batch.state.float_features
action_features = batch.action
seq_len, batch_size, state_dim = state_features.size()
action_dim = action_features.size()[2]
action_feature_num = self.action_feature_num
state_feature_num = self.state_feature_num
feature_importance = torch.zeros(action_feature_num + state_feature_num)
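# Baseline loss on the unmodified batch; each feature's importance is the
# increase over this loss after the feature is perturbed.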
orig_losses = self.trainer.get_loss(batch, state_dim=state_dim)
orig_loss = orig_losses["loss"].cpu().detach().item()
del orig_losses
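# Appending the total dimension to the sorted start indices yields a
# [start, end) column range for each feature, since a single feature may
# occupy more than one column of the flattened action/state vectors.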
action_feature_boundaries = self.sorted_action_feature_start_indices + [
action_dim
]
state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]
for i in range(action_feature_num):
action_features = batch.action.reshape(
(batch_size * seq_len, action_dim)
).data.clone()
# if actions are discrete, an action's feature importance is the loss
# increase due to setting all actions to this action
if self.discrete_action:
assert action_dim == action_feature_num
action_vec = torch.zeros(action_dim)
action_vec[i] = 1
action_features[:] = action_vec
# if actions are continuous, an action's feature importance is the loss
# increase due to masking this action feature to its mean value
else:
boundary_start, boundary_end = (
action_feature_boundaries[i],
action_feature_boundaries[i + 1],
)
action_features[
:, boundary_start:boundary_end
] = self.compute_median_feature_value(
action_features[:, boundary_start:boundary_end]
)
action_features = action_features.reshape((seq_len, batch_size, action_dim))
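# Rebuild the batch with only the action tensor replaced; note that time_diff
# is reset to all ones rather than copied from the original batch.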
new_batch = MemoryNetworkInput(
state=batch.state,
action=action_features,
next_state=batch.next_state,
reward=batch.reward,
time_diff=torch.ones_like(batch.reward).float(),
not_terminal=batch.not_terminal,
step=None,
)
losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
del losses
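# State features are handled analogously to continuous action features:
# each feature's columns are masked to their median value.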
for i in range(state_feature_num):
state_features = batch.state.float_features.reshape(
(batch_size * seq_len, state_dim)
).data.clone()
boundary_start, boundary_end = (
state_feature_boundaries[i],
state_feature_boundaries[i + 1],
)
state_features[
:, boundary_start:boundary_end
] = self.compute_median_feature_value(
state_features[:, boundary_start:boundary_end]
)
state_features = state_features.reshape((seq_len, batch_size, state_dim))
new_batch = MemoryNetworkInput(
state=FeatureData(float_features=state_features),
action=batch.action,
next_state=batch.next_state,
reward=batch.reward,
time_diff=torch.ones_like(batch.reward).float(),
not_terminal=batch.not_terminal,
step=None,
)
losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
feature_importance[i + action_feature_num] = (
losses["loss"].cpu().detach().item() - orig_loss
)
del losses
self.trainer.memory_network.mdnrnn.train()
logger.info(
"**** Debug tool feature importance ****: {}".format(feature_importance)
)
return {"feature_loss_increase": feature_importance.numpy()}