in crypten/mpc/primitives/arithmetic.py [0:0]
def div_(self, y):
    """Divide two tensors element-wise, in place.

    Integer-valued public divisors are handled by truncating the shares
    directly. This covers Python ints, integral Python floats, tensors
    accepted by ``is_int_tensor``, and float tensors whose entries are all
    whole numbers. Any other float divisor is handled by multiplying by its
    reciprocal.

    Args:
        y: The divisor: an ``int``, a ``float``, an integer tensor, or a
            float tensor.

    Returns:
        ``self`` on the integer (truncation) path; otherwise whatever
        ``self.mul_(y.reciprocal())`` returns (presumably ``self`` as well,
        since ``mul_`` is in-place by naming convention -- TODO confirm).

    Raises:
        ValueError: When ``cfg.debug.validation_mode`` is enabled and, for
            some entry, the revealed quotient multiplied by ``y`` differs
            from the original plaintext by 1.0 or more.
        AssertionError: If ``y`` is not one of the supported types.
    """
    # Integral floats are normalized so they take the truncation path.
    # ``float.is_integer`` is False for inf/nan, so those fall through to the
    # reciprocal path instead of raising inside ``int(y)``.
    # TODO: Add test coverage for this code path (next 4 lines)
    if isinstance(y, float) and y.is_integer():
        y = int(y)
    if is_float_tensor(y) and y.frac().eq(0).all():
        y = y.long()

    if isinstance(y, int) or is_int_tensor(y):
        validate = cfg.debug.validation_mode

        if validate:
            tolerance = 1.0
            # Plaintext before division, kept to check the result below.
            tensor = self.get_plain_text()

        # Truncate protocol for dividing by public integers:
        if comm.get().get_world_size() > 2:
            # More than two parties: delegate to the truncation routine of
            # the protocol module named by cfg.mpc.protocol.
            protocol = globals()[cfg.mpc.protocol]
            protocol.truncate(self, y)
        else:
            # At most two parties: each party truncates its own share
            # locally. NOTE(review): this presumably relies on the standard
            # two-party local-truncation trick (small probabilistic error);
            # confirm against the protocol docs.
            self.share = self.share.div_(y, rounding_mode="trunc")

        # Validate: quotient * y must recover the original plaintext to
        # within `tolerance` in every entry.
        if validate:
            if not torch.lt(
                torch.abs(self.get_plain_text() * y - tensor), tolerance
            ).all():
                raise ValueError("Final result of division is incorrect.")

        return self

    # Otherwise multiply by reciprocal
    if isinstance(y, float):
        y = torch.tensor([y], dtype=torch.float, device=self.device)

    assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
    return self.mul_(y.reciprocal())