in functorch/_src/eager_transforms.py [0:0]
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
"""
Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
:attr:`argnums` using reverse-mode autodiff.
Args:
func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
argnums (int or Tuple[int]): Optional, integer or tuple of integers,
saying which arguments to get the Jacobian with respect to.
Default: 0.
has_aux (bool): Flag indicating that :attr:`func` returns a
``(output, aux)`` tuple where the first element is the output of
the function to be differentiated and the second element consists of
auxiliary objects that will not be differentiated.
Default: False.
Returns:
Returns a function that takes in the same inputs as :attr:`func` and
returns the Jacobian of :attr:`func` with respect to the arg(s) at
:attr:`argnums`. If ``has_aux is True``, then the returned function
instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
A basic usage with a pointwise, unary operation will give a diagonal array
as the Jacobian:
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
:func:`jacrev` can be composed with vmap to produce batched
Jacobians:
>>> from functorch import jacrev, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
Additionally, :func:`jacrev` can be composed with itself to produce
Hessians:
>>> from functorch import jacrev
>>> def f(x):
...     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
By default, :func:`jacrev` computes the Jacobian with respect to the first
input. However, it can compute the Jacobian with respect to a different
argument by using :attr:`argnums`:
>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
with respect to multiple arguments:
>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
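If :attr:`has_aux` is True, :func:`jacrev` also returns the auxiliary outputs of
:attr:`func` untouched (a minimal sketch; the function ``f`` below is only
illustrative):
>>> from functorch import jacrev
>>> def f(x):
...     y = x.sin()
...     return y, x.cos()  # (output to differentiate, auxiliary value)
>>>
>>> x = torch.randn(5)
>>> jacobian, aux = jacrev(f, has_aux=True)(x)
>>> assert torch.allclose(jacobian, torch.diag(torch.cos(x)))
>>> assert torch.allclose(aux, x.cos())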
.. note::
There are two cases to consider when using PyTorch ``torch.no_grad`` together with ``jacrev``.
Case 1: Using ``torch.no_grad`` inside a function:
>>> def f(x):
...     with torch.no_grad():
...         c = x ** 2
...     return x - c
In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.
Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:
>>> with torch.no_grad():
...     jacrev(f)(x)
In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
outer one. This is because ``jacrev`` is a "function transform": its result
should not depend on the result of a context manager outside of ``f``.
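For instance, the Jacobian is still computed when :func:`jacrev` is called
under an outer ``torch.no_grad`` (a minimal sketch of Case 2):
>>> x = torch.randn(3)
>>> with torch.no_grad():
...     jacobian = jacrev(torch.sin)(x)
>>> assert torch.allclose(jacobian, torch.diag(torch.cos(x)))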
"""
@wraps(func)
def wrapper_fn(*args):
f_wrapper, primals = _argnums_partial(func, args, argnums)
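# `primals` are the args selected by `argnums` (the ones we differentiate
# with respect to); `f_wrapper` closes over the remaining args.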
vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
if has_aux:
output, vjp_fn, aux = vjp_out
else:
output, vjp_fn = vjp_out
# See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
flat_output, output_spec = tree_flatten(output)
if len(flat_output) == 0:
raise RuntimeError(
'jacrev(f, ...)(*args): expected f to return at least one Tensor, '
'got no Tensors.')
for out in flat_output:
if isinstance(out, torch.Tensor):
continue
raise RuntimeError(
'jacrev(f, ...)(*args): expected f to only return Tensors as '
f'outputs, got {type(out)}')
# NB: vjp already checks that all outputs are tensors
# Step 1: Construct grad_outputs by splitting the standard basis
flat_output_numels = tuple(out.numel() for out in flat_output)
flat_basis = _construct_standard_basis_for(flat_output, flat_output_numels)
basis = tree_unflatten(flat_basis, output_spec)
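# Each element of the basis is a one-hot cotangent (a row of the identity
# over all output elements). vmap-ing vjp_fn over the basis computes one
# row of the Jacobian per cotangent in a single batched reverse pass
# instead of a Python loop over output elements.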
results = vmap(vjp_fn)(basis)
flat_primals, primals_spec = tree_flatten(primals)
flat_results, results_spec = tree_flatten(results)
# Step 2: The returned jacobian is one big tensor per input. In this step,
# we split each Tensor by output.
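# Each entry of flat_results has shape (total_output_numel, *primal.shape);
# splitting along dim 0 yields one chunk per flattened output.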
flat_results = [result.split(flat_output_numels, dim=0) for result in flat_results]
flat_input_flat_output = [
tuple(split.view(out.shape + primal.shape)
for split, out in zip(splits, flat_output))
for splits, primal in zip(flat_results, flat_primals)
]
# Step 3: Right now, `jacobian` is a List[List[Tensor]].
# The outer List corresponds to the number of primals,
# the inner List corresponds to the number of outputs.
# We need to:
# a. Exchange the order of the outer List and inner List
# b. tree_unflatten the inner Lists (which correspond to the primals)
# c. handle the argnums=int case
# d. tree_unflatten the outer List (which corresponds to the outputs)
flat_output_flat_input = tuple(zip(*flat_input_flat_output))
flat_output_input = tuple(tree_unflatten(flat_input, primals_spec)
for flat_input in flat_output_flat_input)
if isinstance(argnums, int):
flat_output_input = tuple(_safe_zero_index(flat_input)
for flat_input in flat_output_input)
output_input = tree_unflatten(flat_output_input, output_spec)
if has_aux:
return output_input, aux
return output_input
return wrapper_fn