in crypten/mpc/primitives/beaver.py [0:0]
def __beaver_protocol(op, x, y, *args, **kwargs):
    """Combine two additively secret-shared tensors with a Beaver triple.

    ``op`` is the name of a torch function ("mul", "matmul", "conv1d",
    "conv2d", "conv_transpose1d" or "conv_transpose2d"). Extra ``*args`` and
    ``**kwargs`` are forwarded both to the triple provider and to
    ``getattr(torch, op)``.

    Protocol:
        1. Obtain random sharings [a], [b] and [c] for ``op`` from the
           default provider.
        2. Mask [x] with [a] and [y] with [b].
        3. Open epsilon = x - a and delta = y - b in one batched reveal.
        4. Return [z] = [c] + op(epsilon, [b]) + op([a], delta)
           + op(epsilon, delta).

    When ``cfg.mpc.active_security`` is enabled, a second triple is drawn and
    used to check the first one before it is consumed.

    Returns:
        The triple's ``c`` share, with the three correction terms accumulated
        into it.

    Raises:
        AssertionError: if ``op`` is not one of the supported names.
        ValueError: if ``x`` and ``y`` are on different devices, or if the
            active-security triple check fails.
    """
    supported_ops = {
        "mul",
        "matmul",
        "conv1d",
        "conv2d",
        "conv_transpose1d",
        "conv_transpose2d",
    }
    assert op in supported_ops

    device = x.device
    if device != y.device:
        raise ValueError(f"x lives on device {x.device} but y on device {y.device}")

    provider = crypten.mpc.get_default_provider()
    x_size, y_size = x.size(), y.size()

    def fresh_triple():
        # Every triple for this call is shaped for (x, y) and the same op/device.
        return provider.generate_additive_triple(
            x_size, y_size, op, *args, device=device, **kwargs
        )

    a, b, c = fresh_triple()

    from .arithmetic import ArithmeticSharedTensor

    if cfg.mpc.active_security:
        # Triple verification using a second triple, after
        # "Multiparty Computation from Somewhat Homomorphic Encryption"
        # (https://eprint.iacr.org/2011/535.pdf).
        # NOTE(review): the identity below uses elementwise `*` for every
        # product regardless of `op`, so it appears to hold only for
        # op == "mul". For matmul/conv ops, shapes and algebra generally do
        # not match. Confirm before relying on active security for them.
        spare_a, spare_b, spare_c = fresh_triple()

        # Random challenge, generated jointly and then opened.
        t = ArithmeticSharedTensor.PRSS(a.size(), device=device).get_plain_text()

        rho = (t * a - spare_a).get_plain_text()
        sigma = (b - spare_b).get_plain_text()
        residual = (
            t * c - spare_c - sigma * spare_a - rho * spare_b - rho * sigma
        ).get_plain_text()
        if torch.any(residual != 0):
            raise ValueError("Beaver Triples verification failed!")

    # Open both masked values together to save a communication round.
    with IgnoreEncodings([a, b, x, y]):
        epsilon, delta = ArithmeticSharedTensor.reveal_batch([x - a, y - b])

    # z = c + op(epsilon, b) + op(a, delta) + op(epsilon, delta)
    bilinear = getattr(torch, op)
    c._tensor += bilinear(epsilon, b._tensor, *args, **kwargs)
    c._tensor += bilinear(a._tensor, delta, *args, **kwargs)
    # The fully public term goes through the shared tensor's own `+=`,
    # unlike the two raw-share updates above.
    c += bilinear(epsilon, delta, *args, **kwargs)
    return c