def forward()

in shap_e/models/nn/checkpoint.py

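This forward implements the recompute-and-differentiate step of gradient checkpointing: the activation inputs are detached, run_function is re-run under torch.enable_grad() so a fresh autograd graph exists, and torch.autograd.grad is seeded with the upstream output gradients to obtain gradients for both the inputs and the parameters.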

    @staticmethod
    def forward(ctx, run_function, length_1, length_2, *args):
        ctx.run_function = run_function
        ctx.length_1 = length_1
        ctx.length_2 = length_2
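        # args is laid out as: length_1 activation/input tensors, then length_2
        # parameter tensors, then one upstream gradient per output tensor.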
        input_tensors = [x.detach().requires_grad_(True) for x in args[:length_1]]
        input_params = list(args[length_1 : length_1 + length_2])
        output_grads = list(args[length_1 + length_2 :])
        ctx.save_for_backward(*input_tensors, *input_params, *output_grads)

        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
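        # Differentiate the recomputed outputs w.r.t. the detached inputs and the
        # parameters, seeding the backward pass with the upstream output gradients.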
        input_grads = torch.autograd.grad(
            output_tensors,
            input_tensors + input_params,
            output_grads,
            allow_unused=True,
        )
        return input_grads
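
Below is a minimal, self-contained sketch of the same recompute-and-differentiate pattern outside of an autograd function; run_function, x, w, and output_grads are illustrative names here, not part of shap_e.

    import torch

    # A toy "segment": one activation input x and one parameter w.
    def run_function(x, w):
        return torch.tanh(x @ w)

    x = torch.randn(2, 3)
    w = torch.randn(3, 4, requires_grad=True)
    output_grads = torch.ones(2, 4)  # upstream gradient for the segment's output

    # Detach the activation input so the recomputed graph starts at this segment,
    # then re-run the segment under enable_grad so there is a graph to differentiate.
    x_detached = x.detach().requires_grad_(True)
    with torch.enable_grad():
        # view_as makes a shallow copy, so an in-place op inside run_function
        # cannot touch the storage of the detached tensor (see the comment above).
        output = run_function(x_detached.view_as(x_detached), w)

    # Mirrors the torch.autograd.grad call in forward(): gradients w.r.t. the
    # detached input and the parameter, seeded with the upstream gradient.
    dx, dw = torch.autograd.grad(
        output, [x_detached, w], output_grads, allow_unused=True
    )

In the listing above, the returned tuple likewise contains one gradient per entry of input_tensors + input_params, with None for any entry the outputs do not depend on, which is why allow_unused=True is passed.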