In torchbiggraph/train_cpu.py:
def _coordinate_train(self, edges, eval_edge_idxs, epoch_idx) -> Stats:
    assert self.config.num_gpus == 0, "GPU training not supported"
    if eval_edge_idxs is not None:
        # Some edges are held out for evaluation: overwrite the held-out
        # positions with the tail indices and truncate, so the shuffle
        # below only covers the remaining training edges.
        num_train_edges = len(edges) - len(eval_edge_idxs)
        train_edge_idxs = torch.arange(len(edges))
        train_edge_idxs[eval_edge_idxs] = torch.arange(num_train_edges, len(edges))
        train_edge_idxs = train_edge_idxs[:num_train_edges]
        edge_perm = train_edge_idxs[torch.randperm(num_train_edges)]
    else:
        edge_perm = torch.randperm(len(edges))
    # Fan the shuffled edge indices out to the worker pool in roughly
    # equal slices; each worker processes its slice Hogwild-style.
    future_all_stats = self.pool.map_async(
        call,
        [
            partial(
                process_in_batches,
                batch_size=self.config.batch_size,
                model=self.model,
                batch_processor=self.trainer,
                edges=edges,
                indices=edge_perm[s],
                # FIXME should we only delay if iteration_idx == 0?
                delay=self.config.hogwild_delay
                if epoch_idx == 0 and self.rank > 0
                else 0,
            )
            for rank, s in enumerate(
                split_almost_equally(edge_perm.size(0), num_parts=self.num_workers)
            )
        ],
    )
    all_stats = get_async_result(future_all_stats, self.pool)
    # Aggregate the per-worker stats into a single average for this pass.
    return Stats.sum(all_stats).average()
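
The index bookkeeping for held-out edges can be exercised on its own. The
snippet below is a minimal standalone sketch, not torchbiggraph code: the toy
sizes and the held-out sample are assumptions, and the variable names simply
mirror the method above.

    import torch

    num_edges = 10
    # Pretend 3 edges were sampled to be held out for evaluation.
    eval_edge_idxs = torch.randperm(num_edges)[:3]
    num_train_edges = num_edges - len(eval_edge_idxs)

    # Overwrite the held-out positions with the tail indices, then truncate,
    # leaving num_train_edges indices to be shuffled and trained on.
    train_edge_idxs = torch.arange(num_edges)
    train_edge_idxs[eval_edge_idxs] = torch.arange(num_train_edges, num_edges)
    train_edge_idxs = train_edge_idxs[:num_train_edges]
    edge_perm = train_edge_idxs[torch.randperm(num_train_edges)]

    print("held out:", eval_edge_idxs.tolist())
    print("training order:", edge_perm.tolist())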