in crypten/cryptensor.py [0:0]
def _get_autograd_forward_function(self, name, grad_fn, in_place):

    # determine if self is a dummy object (the case for staticmethods):
    is_dummy = getattr(self, "__IS_DUMMY__", False)

    def autograd_forward(*args, **kwargs):
        """Forward function that stores data for autograd in result."""
        with CrypTensor.no_grad():

            # only CrypTensors can be children:
            tensor_args = _find_all_cryptensors(args)
            children = tensor_args if is_dummy else [self, *tensor_args]

            # identify whether result requires gradient:
            requires_grad = any(child.requires_grad for child in children)
            if not requires_grad:
                return self.__getattribute__(name)(*args, **kwargs)

            # in-place functions are not supported when requires_grad:
            if in_place:
                raise RuntimeError("Cannot use in-place functions with autograd.")

            # prepare inputs and context for forward call:
            ctx = AutogradContext()
            if not is_dummy:
                args = [self] + list(args)

            # apply correct autograd function:
            result = grad_fn.forward(ctx, *args, **kwargs)

            # output may be a tensor or a tuple of tensors:
            if not isinstance(result, tuple):
                result = (result,)
                remove_tuple = True
            else:
                remove_tuple = False

            # maintain references to children and context in result:
            for res in result:
                res.requires_grad = ctx.is_differentiable(res)
                if res.requires_grad:
                    res.children = children
                    res.grad_fn = grad_fn
                    res.ctx = ctx

            # return result:
            if remove_tuple:
                result = result[0]
            return result

    return autograd_forward
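
A minimal usage sketch of the returned closure. The setup below is hypothetical: `x` and `y` stand for any CrypTensors, and `get_grad_fn` is assumed to be the lookup in crypten/gradients.py that returns the AutogradFunction registered under a name.

    # hypothetical sketch, not part of cryptensor.py:
    from crypten.gradients import get_grad_fn

    # look up the AutogradFunction implementing forward/backward for "mul":
    grad_fn = get_grad_fn("mul")

    # build the autograd-aware forward for x.mul and call it on y:
    forward = x._get_autograd_forward_function("mul", grad_fn, in_place=False)
    z = forward(y)

    # if x or y requires grad, z now carries the bookkeeping set above:
    # z.children == [x, y], z.grad_fn is the "mul" AutogradFunction, and
    # z.ctx is the AutogradContext populated during grad_fn.forward.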