in torchbiggraph/train_cpu.py [0:0]
def _coordinate_eval(self, edges, eval_edge_idxs) -> Optional[Stats]:
    """Fan evaluation out over the worker pool and return the averaged stats.

    The evaluation indices are cut into ``self.num_workers`` nearly equal
    pieces with ``split_almost_equally``. Each piece becomes one pool task,
    which runs ``process_in_batches`` with ``self.evaluator`` as the batch
    processor. The per-task results are merged with ``Stats.sum`` and then
    averaged.

    Args:
        edges: forwarded unchanged to every ``process_in_batches`` task.
        eval_edge_idxs: indices of the edges to evaluate, or ``None`` to skip
            evaluation. It must support ``.size(0)`` and slice indexing.
            Presumably it is a torch index tensor; confirm against the caller.

    Returns:
        The averaged ``Stats`` across all workers, or ``None`` when
        ``eval_edge_idxs`` is ``None``.
    """
    # Computed before the None check on purpose, matching the original
    # ordering: any error raised here surfaces even when evaluation is skipped.
    batch_size_for_eval = round_up_to_nearest_multiple(
        self.config.batch_size, self.config.eval_num_batch_negs
    )
    if eval_edge_idxs is None:
        return None

    self.bucket_logger.debug("Waiting for workers to perform evaluation")

    # One zero-argument task per worker slice; `call` invokes each one in the pool.
    worker_slices = split_almost_equally(
        eval_edge_idxs.size(0), num_parts=self.num_workers
    )
    eval_tasks = []
    for chunk in worker_slices:
        eval_tasks.append(
            partial(
                process_in_batches,
                batch_size=batch_size_for_eval,
                model=self.model,
                batch_processor=self.evaluator,
                edges=edges,
                indices=eval_edge_idxs[chunk],
            )
        )

    pending = self.pool.map_async(call, eval_tasks)
    per_worker_stats = get_async_result(pending, self.pool)
    combined = Stats.sum(per_worker_stats)
    return combined.average()