def _get_autograd_forward_function()

in crypten/cryptensor.py [0:0]


    def _get_autograd_forward_function(self, name, grad_fn, in_place):
        """Return a closure that runs ``grad_fn.forward`` with autograd bookkeeping.

        Args:
            name: attribute looked up on ``self`` and called directly when no
                input requires a gradient (the non-autograd code path).
            grad_fn: autograd function object. Its ``forward(ctx, *args,
                **kwargs)`` is called when any input requires a gradient, and
                it is attached to every differentiable output.
            in_place: whether the wrapped operation is in-place. In-place
                operations are rejected when a gradient is required.

        Returns:
            A function ``autograd_forward(*args, **kwargs)``.
        """
        # A "dummy" self (flagged via __IS_DUMMY__) stands in for staticmethods;
        # in that case self is neither a graph child nor a forward() argument.
        dummy_self = getattr(self, "__IS_DUMMY__", False)

        def autograd_forward(*args, **kwargs):
            """Run the op and record children, grad_fn and ctx on its outputs.

            Raises:
                RuntimeError: if the op is in-place and a gradient is required.
            """
            with CrypTensor.no_grad():
                # Graph children: CrypTensors found among the positional
                # arguments (kwargs are not scanned), plus self unless dummy.
                found = _find_all_cryptensors(args)
                children = found if dummy_self else [self, *found]

                # Fast path: nothing needs a gradient, so call the plain op.
                if not any(child.requires_grad for child in children):
                    return self.__getattribute__(name)(*args, **kwargs)

                if in_place:
                    raise RuntimeError("Cannot use in-place functions with autograd.")

                ctx = AutogradContext()
                forward_inputs = args if dummy_self else (self, *args)
                output = grad_fn.forward(ctx, *forward_inputs, **kwargs)

                # forward() may return a single value or a tuple; annotate each
                # element in place and hand the original object back unchanged.
                outputs = output if isinstance(output, tuple) else (output,)
                for out in outputs:
                    out.requires_grad = ctx.is_differentiable(out)
                    if out.requires_grad:
                        out.children = children
                        out.grad_fn = grad_fn
                        out.ctx = ctx
            return output

        return autograd_forward