def jacrev()

in functorch/_src/eager_transforms.py [0:0]


def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
    """
    Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
    :attr:`argnums` using reverse mode autodiff.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a function that takes in the same inputs as :attr:`func` and
        returns the Jacobian of :attr:`func` with respect to the arg(s) at
        :attr:`argnums`. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    Raises:
        RuntimeError: raised by the returned function (not by ``jacrev``
            itself) if :attr:`func` returns no Tensors, or if any of its
            (flattened) outputs is not a Tensor.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian:

        >>> from functorch import jacrev
        >>> x = torch.randn(5)
        >>> jacobian = jacrev(torch.sin)(x)
        >>> expected = torch.diag(torch.cos(x))
        >>> assert torch.allclose(jacobian, expected)

    :func:`jacrev` can be composed with vmap to produce batched
    Jacobians:

        >>> from functorch import jacrev, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacrev(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)

    Additionally, :func:`jacrev` can be composed with itself to produce
    Hessians:

        >>> from functorch import jacrev
        >>> def f(x):
        ...     return x.sin().sum()
        >>> x = torch.randn(5)
        >>> hessian = jacrev(jacrev(f))(x)
        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacrev` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using :attr:`argnums`:

        >>> from functorch import jacrev
        >>> def f(x, y):
        ...     return x + y ** 2
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=1)(x, y)
        >>> expected = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
    with respect to multiple arguments:

        >>> from functorch import jacrev
        >>> def f(x, y):
        ...     return x + y ** 2
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
        >>> expectedX = torch.diag(torch.ones_like(x))
        >>> expectedY = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian[0], expectedX)
        >>> assert torch.allclose(jacobian[1], expectedY)

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``jacrev``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            ...     with torch.no_grad():
            ...         c = x ** 2
            ...     return x - c

        In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:

            >>> with torch.no_grad():
            ...     jacrev(f)(x)

        In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``jacrev`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    @wraps(func)
    def wrapper_fn(*args):
        # Restrict differentiation to the positional args selected by `argnums`;
        # `primals` are those selected args, `f_wrapper` takes only them.
        f_wrapper, primals = _argnums_partial(func, args, argnums)
        vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
        if has_aux:
            output, vjp_fn, aux = vjp_out
        else:
            output, vjp_fn = vjp_out

        # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
        flat_output, output_spec = tree_flatten(output)
        if not flat_output:
            raise RuntimeError(
                'jacrev(f, ...)(*args): expected f to return at least one Tensor, '
                'got no Tensors.')
        # Every flattened output must be a Tensor: its numel/shape are used
        # below to build the basis and to reshape the Jacobian blocks.
        for out in flat_output:
            if not isinstance(out, torch.Tensor):
                raise RuntimeError(
                    'jacrev(f, ...)(*args): expected f to only return Tensors as '
                    f'outputs, got {type(out)}')

        # Step 1: Construct grad_outputs from a standard basis over all output
        # elements, then evaluate every vjp at once by vmapping over it.
        flat_output_numels = tuple(out.numel() for out in flat_output)
        flat_basis = _construct_standard_basis_for(flat_output, flat_output_numels)
        basis = tree_unflatten(flat_basis, output_spec)

        results = vmap(vjp_fn)(basis)

        flat_primals, primals_spec = tree_flatten(primals)
        flat_results, results_spec = tree_flatten(results)

        # Step 2: The returned jacobian is one big tensor per input whose
        # leading dim spans the elements of all outputs. In this step, we split
        # each Tensor by output and reshape each piece to
        # (*output.shape, *primal.shape).
        flat_results = [result.split(flat_output_numels, dim=0) for result in flat_results]
        flat_input_flat_output = [
            tuple(split.view(out.shape + primal.shape)
                  for split, out in zip(splits, flat_output))
            for splits, primal in zip(flat_results, flat_primals)
        ]

        # Step 3: Right now, `flat_input_flat_output` is a List[Tuple[Tensor]].
        # The outer List corresponds to the (flattened) primals,
        # the inner Tuple corresponds to the (flattened) outputs.
        # We need to:
        # a. Exchange the order of the outer and inner sequences
        # b. tree_unflatten the inner sequences (which correspond to the primals)
        # c. handle the argnums=int case
        # d. tree_unflatten the outer sequence (which corresponds to the outputs)
        flat_output_flat_input = tuple(zip(*flat_input_flat_output))

        flat_output_input = tuple(tree_unflatten(flat_input, primals_spec)
                                  for flat_input in flat_output_flat_input)

        if isinstance(argnums, int):
            # A single int argnum means the caller expects a bare Jacobian per
            # output rather than a 1-tuple of them.
            flat_output_input = tuple(_safe_zero_index(flat_input)
                                      for flat_input in flat_output_input)
        output_input = tree_unflatten(flat_output_input, output_spec)
        if has_aux:
            return output_input, aux
        return output_input
    return wrapper_fn