def backward()

in shap_e/models/nn/checkpoint.py


    @staticmethod
    def backward(ctx, *all_output_grads):
        # Recover the tensors saved in forward(): the first length_1 entries are
        # the input tensors, the next length_2 are the parameters, and the rest
        # are the output gradients used in the first-order grad computation.
        # Inputs and output grads are detached and re-marked as leaves so new
        # gradients can be taken with respect to them.
        args = ctx.saved_tensors
        input_tensors = [x.detach().requires_grad_(True) for x in args[: ctx.length_1]]
        input_params = list(args[ctx.length_1 : ctx.length_1 + ctx.length_2])
        output_grads = [
            x.detach().requires_grad_(True) for x in args[ctx.length_1 + ctx.length_2 :]
        ]

        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in input_tensors]
            # Recompute the forward pass, then rebuild the first-order gradients
            # with create_graph=True so they remain differentiable.
            output_tensors = ctx.run_function(*shallow_copies)
            input_grads = torch.autograd.grad(
                output_tensors,
                input_tensors + input_params,
                output_grads,
                allow_unused=True,
                create_graph=True,
                retain_graph=True,
            )
        # Differentiate the recomputed first-order gradients a second time to get
        # gradients with respect to the inputs, the parameters, and the original
        # output gradients.
        input_grads_grads = torch.autograd.grad(
            input_grads,
            input_tensors + input_params + output_grads,
            all_output_grads,
            allow_unused=True,
        )
        del input_grads
        # The three leading Nones correspond to forward()'s non-tensor arguments
        # (run_function and the two length arguments); the rest line up with *args.
        return (None, None, None) + input_grads_grads
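
For intuition, the sketch below reproduces the core pattern of this backward with plain autograd calls: run a forward pass under grad, take first-order gradients with create_graph=True, then differentiate those gradients a second time. The cubic run_function and the tensor x are made-up stand-ins for illustration; this is not shap_e's checkpoint wrapper API.

    import torch

    def run_function(x):
        # Stand-in for the checkpointed function; any differentiable function works.
        return (x ** 3).sum()

    x = torch.randn(4, requires_grad=True)

    # First-order gradients, keeping the graph so they stay differentiable
    # (mirrors the create_graph=True call inside backward()).
    (first_grad,) = torch.autograd.grad(run_function(x), (x,), create_graph=True)

    # Second-order gradients: differentiate the first-order gradients again
    # (mirrors the second torch.autograd.grad call).
    (second_grad,) = torch.autograd.grad(first_grad.sum(), (x,))

    print(torch.allclose(second_grad, 6 * x))  # grad of 3*x**2 is 6*x -> True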