neuron_explainer/activations/derived_scalars/utils.py:

import torch


def detach_and_clone(x: torch.Tensor, requires_grad: bool) -> torch.Tensor:
    """In some cases, a derived scalar is computed by applying a function to some
    activations and running .backward() on the output, with some tensors intended
    to be backpropagated through and others not. This function supports that: it
    detaches and clones the input tensor so that the backward pass doesn't
    interfere with other places those activations are used, and so that any
    existing gradient information is cleared. It then sets requires_grad to the
    desired value, depending on whether this activation should be backpropagated
    through."""
    return x.detach().clone().requires_grad_(requires_grad)
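
# --- Illustrative usage (not part of the original file): a minimal sketch of the
# scenario the docstring describes. The tensors below are hypothetical stand-ins
# for real model activations; only the copy created with requires_grad=True
# receives gradients, and the original tensors' graphs stay untouched.
if __name__ == "__main__":
    mlp_acts = torch.randn(4, 8, requires_grad=True)   # placeholder activations
    attn_acts = torch.randn(4, 8, requires_grad=True)  # placeholder activations

    # Detach-and-clone both so .backward() here cannot pollute any graph that
    # already uses the originals; choose per tensor whether to backprop through it.
    mlp_copy = detach_and_clone(mlp_acts, requires_grad=True)
    attn_copy = detach_and_clone(attn_acts, requires_grad=False)

    derived_scalar = (mlp_copy * attn_copy).sum()
    derived_scalar.backward()

    assert mlp_copy.grad is not None  # backpropagated through
    assert attn_copy.grad is None     # excluded from backprop
    assert mlp_acts.grad is None      # original tensor unaffected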