in crypten/mpc/context.py
import functools
import logging
import multiprocessing
import tempfile
from operator import itemgetter

import crypten
from crypten.communicator import DistributedCommunicator


def run_multiprocess(world_size):
    """Defines a decorator to run a function across multiple processes.

    Args:
        world_size (int): number of parties / processes to launch.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Unique file path that the parties use to rendezvous while
            # initializing the communicator.
            rendezvous_file = tempfile.NamedTemporaryFile(delete=True).name

            # Queue on which the child processes return their results.
            queue = multiprocessing.Queue()

            # Launch one process per party; each party is identified by its rank.
            processes = [
                multiprocessing.Process(
                    target=_launch,
                    args=(func, rank, world_size, rendezvous_file, queue, args, kwargs),
                )
                for rank in range(world_size)
            ]
            # If the protocol requires a trusted third party (TTP), launch an
            # extra process for the TTP server; it joins with rank == world_size.
            if crypten.mpc.ttp_required():
                processes += [
                    multiprocessing.Process(
                        target=_launch,
                        args=(
                            crypten.mpc.provider.TTPServer,
                            world_size,
                            world_size,
                            rendezvous_file,
                            queue,
                            (),
                            {},
                        ),
                    )
                ]
            # This process will be forked, and we need to re-initialize the
            # communicator in the children. If the parent process happened to
            # call crypten.init(), which might be valid in a Jupyter notebook
            # for instance, then the crypten.init() call in the child
            # processes would not do anything. The call to uninit() here makes
            # sure we actually get to initialize the communicator in the child
            # processes. An alternative fix for this issue would be to use
            # spawn instead of fork, but we run into issues serializing the
            # function in that case.
            was_initialized = DistributedCommunicator.is_initialized()
            if was_initialized:
                crypten.uninit()

            for process in processes:
                process.start()
            for process in processes:
                process.join()

            # Restore the parent's communicator if we tore it down above.
            if was_initialized:
                crypten.init()
            # If any party exited with a non-zero code, report failure.
            successful = [process.exitcode == 0 for process in processes]
            if not all(successful):
                logging.error("One of the parties failed. Check past logs.")
                return None

            # Each child puts a (rank, return_value) pair on the queue; sort
            # by rank so the i-th entry is the return value of party i.
            return_values = []
            while not queue.empty():
                return_values.append(queue.get())
            return [value for _, value in sorted(return_values, key=itemgetter(0))]
        return wrapper

    return decorator
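
# The `_launch` helper used as each process's target is defined elsewhere in
# this module and is not part of this excerpt. The sketch below shows what it
# plausibly does, inferred from how the wrapper calls it; the function name
# `_launch_sketch` and the environment-variable names ("WORLD_SIZE", "RANK",
# "RENDEZVOUS") are assumptions, not the verbatim CrypTen implementation.

import os


def _launch_sketch(func, rank, world_size, rendezvous_file, queue, func_args, func_kwargs):
    # Tell the communicator who this party is and where to rendezvous.
    # NOTE: assumed variable names; the real _launch may differ.
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["RENDEZVOUS"] = "file://%s" % rendezvous_file

    # Initialize the communicator in the child, run the wrapped function,
    # and tear the communicator back down.
    crypten.init()
    return_value = func(*func_args, **func_kwargs)
    crypten.uninit()

    # Report the result to the parent, tagged with this party's rank so the
    # wrapper above can sort the results.
    queue.put((rank, return_value))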
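
# Example usage from user code (a minimal sketch): decorating a function with
# run_multiprocess forks `world_size` processes, runs the function once per
# party, and returns a list with one return value per party, ordered by rank.
# `crypten.cryptensor` and `get_plain_text` are public CrypTen APIs; the
# function name and numbers below are illustrative only.

import crypten.mpc as mpc


@mpc.run_multiprocess(world_size=2)
def double_secret():
    # Each party holds a secret share of x; arithmetic happens on the shares.
    x = crypten.cryptensor([1.0, 2.0, 3.0])
    return (x + x).get_plain_text().tolist()


if __name__ == "__main__":
    result = double_secret()  # e.g. [[2.0, 4.0, 6.0], [2.0, 4.0, 6.0]]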