in crypten/__init__.py [0:0]
def _setup_prng():
    """
    Create the shared random seeds used to build pseudo-random sharings of
    zero. For every device, four generators are seeded:

        "prev"  - seed shared with the previous party
        "next"  - seed shared with the next party
        "local" - seed known only to this party (kept separate from torch's
                  default RNG so torch.manual_seed cannot interfere)
        "global"- seed shared by all parties

    The "prev"/"next" seeds are exchanged so that each process shares one
    seed with the previous rank and one with the next rank. This allows `n`
    random values to be generated, each known to exactly two of the `n`
    parties. For arithmetic sharing, one party adds the value while the
    other subtracts it, yielding a pseudo-random sharing of zero. (Binary
    sharing does the same with bitwise-xor instead of add / subtract.)
    """
    global generators

    # Build one torch.Generator per (seed-kind, device) pair. The CPU is
    # always present; when CUDA is available, also cover the generic "cuda"
    # device and every indexed "cuda:i" device.
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        names = ["cuda"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        devices.extend(torch.device(name) for name in names)
    for device in devices:
        for key in generators.keys():
            generators[key][device] = torch.Generator(device=device)

    # Seeds are drawn from os.urandom(8) — not a PRNG — so other parties
    # cannot guess them and forked processes do not repeat a seed. The
    # subtraction maps the unsigned 64-bit value into torch's signed
    # int64 range.
    def _fresh_seed():
        return int.from_bytes(os.urandom(8), "big") - 2 ** 63

    # Exchange "next"/"prev" seeds with the neighboring parties: send ours
    # to the next rank while receiving theirs from the previous rank.
    next_seed = torch.tensor(_fresh_seed())
    prev_seed = torch.tensor([0], dtype=torch.long)  # populated by irecv
    world_size = comm.get().get_world_size()
    rank = comm.get().get_rank()
    if world_size >= 2:  # Guard against segfaults when world_size == 1.
        send_req = comm.get().isend(next_seed, (rank + 1) % world_size)
        recv_req = comm.get().irecv(prev_seed, src=(rank - 1) % world_size)
        send_req.wait()
        recv_req.wait()
    else:
        prev_seed = next_seed
    prev_seed = prev_seed.item()
    next_seed = next_seed.item()

    # Each party keeps a private local seed of its own.
    local_seed = _fresh_seed()

    # All parties share one global seed, broadcast from rank 0, for
    # synchronized randomness.
    global_seed = comm.get().broadcast(torch.tensor(_fresh_seed()), 0).item()

    # Seed every generator with the party-wide seeds.
    # Note: seeds are coordinated across cuda devices so that we can run
    # one party per gpu. Supporting a single party spanning multiple gpus
    # across machines would require modifying this.
    for device in generators["prev"].keys():
        generators["prev"][device].manual_seed(prev_seed)
        generators["next"][device].manual_seed(next_seed)
        generators["local"][device].manual_seed(local_seed)
        generators["global"][device].manual_seed(global_seed)