in slowfast/visualization/gradcam_utils.py [0:0]
def _calculate_localization_map(self, inputs, labels=None):
    """
    Build one Grad-CAM localization map per input pathway.

    The model is run on clones of `inputs`, the selected class scores are
    summed and back-propagated, and the gradients/activations stored under
    each target layer are combined into a heat-map. Every map is resized to
    the (T, H, W) of its pathway input and min-max scaled per instance.

    NOTE(review): `self.gradients` and `self.activations` are read here but
    not written; presumably they are filled by hooks that fire during the
    forward/backward pass below -- confirm against the hook registration.

    Args:
        inputs (list of tensor(s)): the input clips, one 5-D tensor per
            pathway; the last three dimensions are read as (T, H, W).
        labels (Optional[tensor]): labels of the current input clips, used
            as gather indices into the predictions. A 1-D tensor gets a
            trailing dimension added. When None, the highest score of each
            prediction row is used instead.
    Returns:
        localization_maps (list of tensor(s)): one map per entry of
            `inputs`, each of shape (B, 1, T, H, W).
        preds (tensor): shape (n_instances, n_class). Model predictions for `inputs`.
    """
    assert len(inputs) == len(
        self.target_layers
    ), "Must register the same number of target layers as the number of input pathways."

    # Forward on copies so the caller's tensors are never touched.
    preds = self.model([clip.clone() for clip in inputs])

    if labels is not None:
        if labels.ndim == 1:
            labels = labels.unsqueeze(-1)
        score = torch.gather(preds, dim=1, index=labels)
    else:
        score = torch.max(preds, dim=-1)[0]

    self.model.zero_grad()
    # Reduce to a scalar so a single backward pass covers the whole batch.
    torch.sum(score).backward()

    localization_maps = []
    for idx, clip in enumerate(inputs):
        _, _, out_t, out_h, out_w = clip.size()

        layer = self.target_layers[idx]
        grads = self.gradients[layer]
        acts = self.activations[layer]
        batch, channels, grad_t, _, _ = grads.size()

        # Channel weights: gradients averaged over the spatial positions,
        # kept separately for every temporal index.
        alpha = grads.view(batch, channels, grad_t, -1).mean(dim=3)
        alpha = alpha.view(batch, channels, grad_t, 1, 1)

        # Weighted sum over channels, keeping only positive evidence.
        cam = F.relu((alpha * acts).sum(dim=1, keepdim=True))
        cam = F.interpolate(
            cam,
            size=(out_t, out_h, out_w),
            mode="trilinear",
            align_corners=False,
        )

        # Per-instance extrema, shaped to broadcast against the 5-D map.
        flat = cam.view(batch, -1)
        cam_min = flat.min(dim=-1, keepdim=True)[0].reshape(batch, 1, 1, 1, 1)
        cam_max = flat.max(dim=-1, keepdim=True)[0].reshape(batch, 1, 1, 1, 1)

        # Normalize the localization map; the epsilon guards constant maps.
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-6)

        localization_maps.append(cam.data)

    return localization_maps, preds