in crypten/communicator/communicator.py [0:0]
def _logging(func):
"""
Decorator that performs logging of communication statistics.
NOTE: Each party performs its own logging of communication, so one needs to
sum the number of bytes communicated over all parties and divide by two
(to prevent double-counting) to obtain the number of bytes communicated in
the overall system.
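
    For example, if party 0 sends a 50-byte tensor to party 1, each party logs
    50 bytes; summing over parties gives 100 bytes, and dividing by two
    recovers the 50 bytes actually communicated.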
"""
    # imports used by the wrapper below (module-level in the original file):
    import sys
    import timeit
    from functools import wraps

    from crypten.config import cfg

@wraps(func)
def logging_wrapper(self, *args, **kwargs):
        # TODO: Replace this hack that short-circuits the inputs of some
        # functions when world_size is 1:
world_size = self.get_world_size()
if world_size < 2:
if func.__name__ in ["gather", "all_gather"]:
return [args[0]]
elif len(args) > 0:
return args[0]
# only log communication if needed:
if cfg.communicator.verbose:
rank = self.get_rank()
_log = self._log_communication
            # count the number of elements communicated for each MPI collective:
if func.__name__ == "barrier":
_log(0)
elif func.__name__ in ["send", "recv", "isend", "irecv"]:
_log(args[0].nelement()) # party sends or receives tensor
elif func.__name__ == "scatter":
if args[1] == rank: # party scatters P - 1 tensors
nelements = sum(
x.nelement() for idx, x in enumerate(args[0]) if idx != rank
)
_log(nelements) # NOTE: We deal with other parties later
elif func.__name__ == "all_gather":
_log(2 * (world_size - 1) * args[0].nelement())
# party sends and receives P - 1 tensors
elif func.__name__ == "send_obj":
nbytes = sys.getsizeof(args[0])
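                # convert bytes to element units so all logged values share one unit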
_log(nbytes / self.BYTES_PER_ELEMENT) # party sends object
elif func.__name__ == "broadcast_obj":
nbytes = sys.getsizeof(args[0])
_log(nbytes / self.BYTES_PER_ELEMENT * (world_size - 1))
# party sends object to P - 1 parties
elif func.__name__ in ["broadcast", "gather", "reduce"]:
multiplier = world_size - 1 if args[1] == rank else 1
# broadcast: party sends tensor to P - 1 parties, or receives 1 tensor
# gather: party receives P - 1 tensors, or sends 1 tensor
# reduce: party receives P - 1 tensors, or sends 1 tensor
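                # e.g., broadcasting a 10-element tensor from rank 0 among 4
                # parties logs 30 elements at rank 0 and 10 at every other rank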
if "batched" in kwargs and kwargs["batched"]:
nelements = sum(x.nelement() for x in args[0])
_log(nelements * multiplier)
else:
_log(args[0].nelement() * multiplier)
elif func.__name__ == "all_reduce":
# each party sends and receives one tensor in ring implementation
if "batched" in kwargs and kwargs["batched"]:
nelements = sum(2 * x.nelement() for x in args[0])
_log(nelements)
else:
_log(2 * args[0].nelement())
# execute and time the MPI collective:
tic = timeit.default_timer()
result = func(self, *args, **kwargs)
toc = timeit.default_timer()
self._log_communication_time(toc - tic)
            # for some functions, we only know the communicated size afterwards:
if func.__name__ == "scatter" and args[1] != rank:
_log(result.nelement()) # party receives 1 tensor
if func.__name__ == "recv_obj":
_log(sys.getsizeof(result) / self.BYTES_PER_ELEMENT)
# party receives 1 object
return result
return func(self, *args, **kwargs)
return logging_wrapper
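

# Usage sketch (illustrative; not part of communicator.py): `_logging` is meant
# to wrap the collective methods of a concrete communicator. The class below is
# a hypothetical stand-in that only provides the minimal interface the wrapper
# relies on: get_rank, get_world_size, _log_communication,
# _log_communication_time, and BYTES_PER_ELEMENT. It assumes torch tensors and
# that `cfg.communicator.verbose` has been enabled.
import torch


class _ToyCommunicator:
    BYTES_PER_ELEMENT = 8  # assumption: 8-byte (64-bit) elements

    def __init__(self, rank=0, world_size=2):
        self.rank = rank
        self.world_size = world_size
        self.comm_elements = 0  # running count of logged elements
        self.comm_time = 0.0  # running total of logged communication time

    def get_rank(self):
        return self.rank

    def get_world_size(self):
        return self.world_size

    def _log_communication(self, nelements):
        self.comm_elements += nelements

    def _log_communication_time(self, time):
        self.comm_time += time

    @_logging
    def all_reduce(self, tensor, batched=False):
        return tensor  # a real communicator would exchange data here


# With verbose logging enabled, all_reduce on a 10-element tensor logs
# 2 * 10 = 20 elements (one send and one receive in the ring implementation):
#
#   comm = _ToyCommunicator()
#   comm.all_reduce(torch.zeros(10))
#   assert comm.comm_elements == 20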