in torchbiggraph/parameter_sharing.py [0:0]
def handle_get(self, rank: int, key: str, send_size: int) -> None:
    """Send the tensor stored under ``key`` to the peer at ``rank``.

    Message sequence, as issued by this method:

    1. If ``send_size`` is truthy, a header of two ``torch.long`` tensors is
       sent first: ``[ndimension, index of the dtype in _dtypes]`` followed
       by the tensor's size. If the key is unknown, the header is
       ``[-1, -1]`` instead and nothing else is sent.
    2. The payload: a single ``td.send`` of the whole tensor when
       ``self.groups`` is ``None``; otherwise the flattened tensor is cut
       into one slice per group and the slices are sent concurrently with
       ``td.isend``, using the slice's position as the message tag.

    Args:
        rank: Destination rank handed to ``td.send`` / ``td.isend``.
        key: Name to look up in ``self.parameters``.
        send_size: Truthy when the header described above must precede the
            payload.

    Raises:
        AssertionError: If ``key`` is not in ``self.parameters`` and
            ``send_size`` is falsy.
    """
    if key not in self.parameters:
        if not send_size:
            # Raised explicitly (not via `assert`) so that the check still
            # runs under `python -O`; otherwise the [-1, -1] header below
            # would be sent even though no header was requested.
            raise AssertionError("Key %s not found" % key)
        td.send(torch.tensor([-1, -1], dtype=torch.long), rank)
        return

    data = self.parameters[key]
    if send_size:
        type_idx = _dtypes.index(data.dtype)
        td.send(torch.tensor([data.ndimension(), type_idx], dtype=torch.long), rank)
        td.send(torch.tensor(list(data.size()), dtype=torch.long), rank)

    start_t = time.monotonic()
    if self.groups is None:
        td.send(data, dst=rank)
    else:
        flattened_data = data.flatten()
        flattened_size = flattened_data.shape[0]
        slices = split_almost_equally(flattened_size, num_parts=len(self.groups))
        # Start every partial send before waiting on any of them so that the
        # transfers on the different groups can overlap.
        outstanding_work = [
            td.isend(tensor=flattened_data[slice_], dst=rank, group=pg, tag=idx)
            for idx, (pg, slice_) in enumerate(zip(self.groups, slices))
        ]
        for w in outstanding_work:
            w.wait()
    end_t = time.monotonic()

    if self.log_stats:
        stats_size = data.numel() * data.element_size()
        stats_time = end_t - start_t
        # Two monotonic() readings may be identical for very short sends;
        # report an infinite rate rather than raising ZeroDivisionError
        # after the data has already been delivered.
        stats_rate = stats_size / stats_time if stats_time > 0 else float("inf")
        logger.debug(
            f"Sent tensor {key} to client {rank}: "
            f"{stats_size:,} bytes "
            f"in {stats_time:,g} seconds "
            f"=> {stats_rate:,.0f} B/s"
        )