in captum/attr/_core/layer/layer_feature_ablation.py [0:0]
def layer_forward_func(*args):
    # args packs the (possibly ablated) layer tensors first, then the
    # original inputs, with the number of layer tensors as the final element.
    layer_length = args[-1]
    layer_input = args[:layer_length]
    original_inputs = args[layer_length:-1]

    # Resolve device ids, falling back to those stored on a DataParallel
    # forward_func if none were passed to the constructor.
    device_ids = self.device_ids
    if device_ids is None:
        device_ids = getattr(self.forward_func, "device_ids", None)

    # Map each device to the slice of the layer input that lives on it.
    all_layer_inputs = {}
    if device_ids is not None:
        scattered_layer_input = scatter(layer_input, target_gpus=device_ids)
        for device_tensors in scattered_layer_input:
            all_layer_inputs[device_tensors[0].device] = device_tensors
    else:
        all_layer_inputs[layer_input[0].device] = layer_input

    def forward_hook(module, inp, out=None):
        # Substitute the layer's input (pre-hook) or output (forward hook)
        # with the ablated tensors assigned to this module's device.
        device = _extract_device(module, inp, out)
        is_layer_tuple = (
            isinstance(out, tuple) if out is not None else isinstance(inp, tuple)
        )
        if device not in all_layer_inputs:
            raise AssertionError(
                "Layer input not placed on appropriate "
                "device. If using a DataParallel model, either provide the "
                "DataParallel model as forward_func or provide device ids"
                " to the constructor."
            )
        if not is_layer_tuple:
            return all_layer_inputs[device][0]
        return all_layer_inputs[device]

    hook = None
    try:
        # A pre-hook overrides the layer's input; a forward hook overrides
        # its output, depending on attribute_to_layer_input.
        if attribute_to_layer_input:
            hook = self.layer.register_forward_pre_hook(forward_hook)
        else:
            hook = self.layer.register_forward_hook(forward_hook)
        eval = _run_forward(self.forward_func, original_inputs, target=target)
    finally:
        # Always detach the hook, even if the forward pass raises.
        if hook is not None:
            hook.remove()
    return eval
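
# For reference, the mechanism this relies on is that a PyTorch forward hook
# (or forward pre-hook) can override a module's output (or input) simply by
# returning a value. Below is a minimal standalone sketch of that pattern,
# outside Captum; the TinyNet module, layer sizes, and the replacement tensor
# are illustrative assumptions, not part of the library.

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Illustrative two-layer model (assumed for this sketch).
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 3)
        self.layer2 = nn.Linear(3, 2)

    def forward(self, x):
        return self.layer2(self.layer1(x))

model = TinyNet()
replacement = torch.zeros(1, 3)  # tensor to inject as layer1's output

def override_hook(module, inp, out):
    # Returning a non-None value from a forward hook replaces the module's
    # output, so layer2 sees `replacement` instead of layer1(x).
    return replacement

handle = model.layer1.register_forward_hook(override_hook)
try:
    result = model(torch.randn(1, 4))  # layer2 runs on the injected tensor
finally:
    handle.remove()  # detach the hook, mirroring the try/finally above
print(result.shape)  # torch.Size([1, 2])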