in captum/attr/_core/deep_lift.py [0:0]
def _hook_main_model(self) -> List[RemovableHandle]:
    """Register hooks that batch inputs with baselines on ``self.model``.

    A forward pre-hook rewrites the positional call arguments
    ``(inputs, baselines, *additional_args)`` into one concatenated batch.
    A forward hook then splits the resulting output back into its two
    halves and stacks them along dim 1.

    If ``self.model`` is an ``nn.DataParallel`` or
    ``nn.parallel.DistributedDataParallel`` wrapper, the hooks are attached
    to the wrapped ``self.model.module`` rather than to the wrapper itself.

    Returns:
        ``[pre_hook_handle, forward_hook_handle]``. The caller is responsible
        for calling ``.remove()`` on these (presumably once attribution is
        done; confirm against caller).
    """

    def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple:
        # Positional args arrive as (inputs, baselines, *additional_args).
        inputs = baseline_inputs_add_args[0]
        baselines = baseline_inputs_add_args[1]
        additional_args = baseline_inputs_add_args[2:]

        # Concatenate each input with its baseline along dim 0 so a single
        # forward pass evaluates both. NOTE(review): zip silently truncates
        # if inputs and baselines differ in length; callers are assumed to
        # pass matching sequences.
        baseline_input_tsr = tuple(
            torch.cat([inp, baseline]) for inp, baseline in zip(inputs, baselines)
        )
        if not additional_args:
            return baseline_input_tsr

        # Expand the extra forward args with factor 2 / repeat mode,
        # presumably so they line up with the doubled batch (semantics of
        # _expand_additional_forward_args live elsewhere; confirm there).
        expanded_additional_args = cast(
            Tuple,
            _expand_additional_forward_args(
                additional_args, 2, ExpansionTypes.repeat
            ),
        )
        return (*baseline_input_tsr, *expanded_additional_args)

    def forward_hook(module: Module, inputs: Tuple, outputs: Tensor) -> Tensor:
        # Split the batch back into its two halves along dim 0 and stack
        # them on dim 1: index 0 holds the first half (the inputs) and
        # index 1 holds the second half (the baselines).
        return torch.stack(torch.chunk(outputs, 2), dim=1)

    # DataParallel / DDP wrappers expose the wrapped network as `.module`;
    # that inner module, not the wrapper, receives the hooks.
    is_wrapped = isinstance(
        self.model, (nn.DataParallel, nn.parallel.DistributedDataParallel)
    )
    target = self.model.module if is_wrapped else self.model  # type: ignore
    return [
        target.register_forward_pre_hook(pre_hook),  # type: ignore
        target.register_forward_hook(forward_hook),  # type: ignore
    ]