in crypten/mpc/primitives/arithmetic.py [0:0]
def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):  # noqa:C901
    """Apply an arithmetic operation between this shared tensor and `y`.

    `y` may be *public* (an int, float, or torch tensor) or *private*
    (another ArithmeticSharedTensor); unsupported types raise TypeError.

    Args:
        y: the second operand (public scalar/tensor or private shared tensor).
        op (str): one of "add", "sub", "mul", "matmul", "conv1d", "conv2d",
            "conv_transpose1d", "conv_transpose2d" (asserted below).
        inplace (bool): if True, mutate `self` instead of a clone.
        *args, **kwargs: forwarded to the underlying torch / protocol op
            (e.g. convolution stride/padding — not inspected here).

    Returns:
        The resulting ArithmeticSharedTensor (`self` when inplace).
        NOTE: for private additive ops with a smaller-scale `y`, `y` is
        re-encoded IN PLACE via `y.encode_as_(result)` — a visible side
        effect on the argument.
    """
    assert op in [
        "add",
        "sub",
        "mul",
        "matmul",
        "conv1d",
        "conv2d",
        "conv_transpose1d",
        "conv_transpose2d",
    ], f"Provided op `{op}` is not a supported arithmetic function"

    additive_func = op in ["add", "sub"]
    public = isinstance(y, (int, float)) or is_tensor(y)
    private = isinstance(y, ArithmeticSharedTensor)

    if inplace:
        result = self
        # Switch to torch's in-place variant ("add_", "sub_", "mul_") only
        # where an in-place form is actually used below; matmul/conv have
        # no in-place torch equivalent and keep their original name.
        if additive_func or (op == "mul" and public):
            op += "_"
    else:
        result = self.clone()

    if public:
        # Encode the public value into the same fixed-point representation
        # as `result` (and on the same device as self).
        y = result.encoder.encode(y, device=self.device)

        if additive_func:  # ['add', 'sub']
            if result.rank == 0:
                # Only the rank-0 party folds the public value into its
                # share; the sum of shares then shifts by exactly `y`.
                result.share = getattr(result.share, op)(y)
            else:
                # Other parties keep their share values but broadcast the
                # share to the result shape implied by `y`.
                result.share = torch.broadcast_tensors(result.share, y)[0]
        elif op == "mul_":  # ['mul_']
            # In-place public multiply: every party scales its own share.
            result.share = result.share.mul_(y)
        else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
            # Out-of-place public op: apply the torch function of the same
            # name directly to the local share.
            result.share = getattr(torch, op)(result.share, y, *args, **kwargs)
    elif private:
        if additive_func:  # ['add', 'sub', 'add_', 'sub_']
            # Re-encode if necessary:
            # bring both operands to the larger fixed-point scale before
            # combining shares. NOTE: the first branch mutates `y`.
            if self.encoder.scale > y.encoder.scale:
                y.encode_as_(result)
            elif self.encoder.scale < y.encoder.scale:
                result.encode_as_(y)
            result.share = getattr(result.share, op)(y.share)
        else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
            # Private multiplicative ops require an MPC protocol; the
            # protocol class is looked up by name from module globals
            # (e.g. as configured in cfg.mpc.protocol).
            protocol = globals()[cfg.mpc.protocol]
            # set_ copies the protocol result into result's existing share
            # tensor in place (preserves the tensor object/storage).
            result.share.set_(
                getattr(protocol, op)(result, y, *args, **kwargs).share.data
            )
    else:
        raise TypeError("Cannot %s %s with %s" % (op, type(y), type(self)))

    # Scale by encoder scale if necessary
    # (multiplicative ops multiply the two fixed-point scales together, so
    # one factor of scale must be divided back out or accounted for).
    if not additive_func:
        if public:  # scale by self.encoder.scale
            if self.encoder.scale > 1:
                return result.div_(result.encoder.scale)
            else:
                # scale == 1: nothing to divide; just adopt self's encoder.
                result.encoder = self.encoder
        else:  # scale by larger of self.encoder.scale and y.encoder.scale
            if self.encoder.scale > 1 and y.encoder.scale > 1:
                # Both operands were scaled: divide one factor back out.
                return result.div_(result.encoder.scale)
            elif self.encoder.scale > 1:
                # Only self was scaled: result already carries exactly one
                # scale factor, so keep self's encoder.
                result.encoder = self.encoder
            else:
                # Only y (or neither) was scaled: adopt y's encoder.
                result.encoder = y.encoder

    return result