def _hook_main_model()

in captum/attr/_core/deep_lift.py [0:0]


    def _hook_main_model(self) -> List[RemovableHandle]:
        """Install hooks so one forward call evaluates inputs and baselines together.

        The hooks go on ``self.model``. If the model is wrapped in
        ``nn.DataParallel`` or ``nn.parallel.DistributedDataParallel``, they go
        on the wrapped ``.module`` instead.

        * The forward pre-hook expects the positional call arguments to be
          ``(inputs, baselines, *additional_args)``. It concatenates each
          input tensor with its paired baseline tensor along dim 0 (inputs
          first). It passes these tensors, followed by the expanded additional
          args, to the module's ``forward``.
        * The forward hook splits the module output into two chunks along
          dim 0 and stacks them along a new dim 1. Assuming the module keeps
          the doubled batch along dim 0, ``out[:, 0]`` corresponds to the
          inputs and ``out[:, 1]`` to the baselines.

        Returns:
            The pre-hook handle and the forward-hook handle, in that order.
            They are not removed here; call ``.remove()`` on them when done.
        """

        def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple:
            # Positional args of the hooked call: (inputs, baselines, *rest).
            inputs = baseline_inputs_add_args[0]
            baselines = baseline_inputs_add_args[1]
            # Empty tuple when the call had only (inputs, baselines).
            additional_args = baseline_inputs_add_args[2:]

            # Pair tensors element-wise and put inputs before baselines on
            # dim 0. NOTE: zip stops at the shorter sequence, so a length
            # mismatch between inputs and baselines is not detected here.
            baseline_input_tsr = tuple(
                torch.cat([inp, base]) for inp, base in zip(inputs, baselines)
            )
            if not additional_args:
                return baseline_input_tsr

            # Presumably this repeats each extra arg twice so it lines up with
            # the doubled batch built above. TODO: confirm against
            # _expand_additional_forward_args.
            expanded_additional_args = cast(
                Tuple,
                _expand_additional_forward_args(
                    additional_args, 2, ExpansionTypes.repeat
                ),
            )
            return (*baseline_input_tsr, *expanded_additional_args)

        def forward_hook(module: Module, inputs: Tuple, outputs: Tensor):
            # Undo the concatenation: split dim 0 into two chunks (input rows,
            # then baseline rows) and stack them side by side on dim 1.
            return torch.stack(torch.chunk(outputs, 2), dim=1)

        target = self.model
        if isinstance(
            target, (nn.DataParallel, nn.parallel.DistributedDataParallel)
        ):
            # Hook the wrapped module rather than the parallel wrapper
            # (presumably so the hooks follow the module into each replica).
            target = target.module
        return [
            target.register_forward_pre_hook(pre_hook),  # type: ignore
            target.register_forward_hook(forward_hook),
        ]