in torchbiggraph/train_gpu.py [0:0]
def _coordinate_train(self, edges, eval_edge_idxs, epoch_idx) -> Stats:
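    """Train one (lhs, rhs) bucket by fanning its sub-buckets out to the GPU pool.

    Rough flow, as implemented below: optionally hold out the eval edges, randomly
    permute each partition's entities and split them into sub-partitions,
    sub-bucket the edges accordingly, schedule the sub-buckets across the GPU
    worker pool (locking sub-partitions so no two GPUs touch the same one at
    once), then undo the permutation and restore the eval edges.
    """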
    tk = TimeKeeper()

    config = self.config
    holder = self.holder
    cur_b = self.cur_b
    bucket_logger = self.bucket_logger
    num_edges = len(edges)
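
    # On a diagonal bucket (lhs partition == rhs partition) use twice as many
    # sub-partitions, so that the non-bipartite schedule built below has enough
    # non-conflicting sub-buckets to keep every GPU busy.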
    if cur_b.lhs == cur_b.rhs and config.num_gpus > 1:
        num_subparts = 2 * config.num_gpus
    else:
        num_subparts = config.num_gpus

    edges_lhs = edges.lhs.tensor
    edges_rhs = edges.rhs.tensor
    edges_rel = edges.rel
    eval_edges_lhs = None
    eval_edges_rhs = None
    eval_edges_rel = None
    assert edges.weight is None, "Edge weights not implemented in GPU mode yet"
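
    # Hold out the evaluation edges: copy them aside, overwrite their slots with
    # the trailing edges, then truncate the edge tensors so training never sees
    # them. They are restored verbatim at the end of this method.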
    if eval_edge_idxs is not None:
        bucket_logger.debug("Removing eval edges")
        tk.start("remove_eval")
        num_eval_edges = len(eval_edge_idxs)
        eval_edges_lhs = edges_lhs[eval_edge_idxs]
        eval_edges_rhs = edges_rhs[eval_edge_idxs]
        eval_edges_rel = edges_rel[eval_edge_idxs]
        edges_lhs[eval_edge_idxs] = edges_lhs[-num_eval_edges:].clone()
        edges_rhs[eval_edge_idxs] = edges_rhs[-num_eval_edges:].clone()
        edges_rel[eval_edge_idxs] = edges_rel[-num_eval_edges:].clone()
        edges_lhs = edges_lhs[:-num_eval_edges]
        edges_rhs = edges_rhs[:-num_eval_edges]
        edges_rel = edges_rel[:-num_eval_edges]
        bucket_logger.debug(
            f"Time spent removing eval edges: {tk.stop('remove_eval'):.4f} s"
        )

    bucket_logger.debug("Splitting edges into sub-buckets")
    tk.start("mapping_edges")
    # randomly permute the entities, to get a random subbucketing
    perm_holder = {}
    rev_perm_holder = {}
    for (entity, part), embs in holder.partitioned_embeddings.items():
        perm = _C.randperm(self.entity_counts[entity][part], os.cpu_count())
        _C.shuffle(embs, perm, os.cpu_count())
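        # Shuffle the per-entity optimizer state ("sum") with the same
        # permutation so it stays row-aligned with its embeddings.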
        optimizer = self.trainer.partitioned_optimizers[entity, part]
        (optimizer_state,) = optimizer.state.values()
        _C.shuffle(optimizer_state["sum"], perm, os.cpu_count())
        perm_holder[entity, part] = perm
        rev_perm = _C.reverse_permutation(perm, os.cpu_count())
        rev_perm_holder[entity, part] = rev_perm
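
    # After permuting, each partition is cut into num_subparts contiguous,
    # roughly equal slices; each slice is one sub-partition.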
    subpart_slices: Dict[Tuple[EntityName, Partition, SubPartition], slice] = {}
    for entity_name, part in holder.partitioned_embeddings.keys():
        num_entities = self.entity_counts[entity_name][part]
        for subpart, subpart_slice in enumerate(
            split_almost_equally(num_entities, num_parts=num_subparts)
        ):
            subpart_slices[entity_name, part, subpart] = subpart_slice
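
    # _C.sub_bucket (a C++ helper) maps every edge through the per-relation
    # permutations and groups the edges into a num_subparts x num_subparts grid
    # of (lhs sub-partition, rhs sub-partition) sub-buckets; the shared_* tensors
    # passed in are presumably reused as output buffers to avoid reallocation.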
    subbuckets = _C.sub_bucket(
        edges_lhs,
        edges_rhs,
        edges_rel,
        [self.entity_counts[r.lhs][cur_b.lhs] for r in config.relations],
        [perm_holder[r.lhs, cur_b.lhs] for r in config.relations],
        [self.entity_counts[r.rhs][cur_b.rhs] for r in config.relations],
        [perm_holder[r.rhs, cur_b.rhs] for r in config.relations],
        self.shared_lhs,
        self.shared_rhs,
        self.shared_rel,
        num_subparts,
        num_subparts,
        os.cpu_count(),
        config.dynamic_relations,
    )
    bucket_logger.debug(
        "Time spent splitting edges into sub-buckets: "
        f"{tk.stop('mapping_edges'):.4f} s"
    )
    bucket_logger.debug("Done splitting edges into sub-buckets")
    bucket_logger.debug(f"{subpart_slices}")

    tk.start("scheduling")
    busy_gpus: Set[int] = set()
    all_stats: List[Stats] = []
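    # Build one schedule of (lhs_subpart, rhs_subpart) pairs per GPU. Off-diagonal
    # buckets are bipartite (lhs and rhs entities come from different partitions),
    # so any lhs sub-partition can be paired with any rhs one; diagonal buckets
    # need the more constrained non-bipartite schedule.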
    if cur_b.lhs != cur_b.rhs:  # Graph is bipartite!!
        gpu_schedules = build_bipartite_schedule(num_subparts)
    else:
        gpu_schedules = build_nonbipartite_schedule(num_subparts)
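    # Pad every schedule with two None sentinels so that the "current" and "next"
    # sub-bucket lookups below never run past the end; a None current bucket
    # means that GPU's schedule is exhausted.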
    for s in gpu_schedules:
        s.append(None)
        s.append(None)
    index_in_schedule = [0 for _ in range(self.gpu_pool.num_gpus)]
    locked_parts = set()
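
    # schedule() tries to start the next sub-bucket on the given GPU: it bails out
    # if the GPU is already busy, its schedule is exhausted, or a sub-partition it
    # needs is locked by another GPU; otherwise it locks those sub-partitions and
    # dispatches the job to the GPU worker pool. The indices of the following
    # sub-bucket are passed along too, presumably so the worker can prefetch its
    # embeddings.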
    def schedule(gpu_idx: GPURank) -> None:
        if gpu_idx in busy_gpus:
            return
        this_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx]]
        next_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx] + 1]
        if this_bucket is None:
            return
        subparts = {
            (e, cur_b.lhs, this_bucket[0]) for e in holder.lhs_partitioned_types
        } | {(e, cur_b.rhs, this_bucket[1]) for e in holder.rhs_partitioned_types}
        if any(k in locked_parts for k in subparts):
            return
        for k in subparts:
            locked_parts.add(k)
        busy_gpus.add(gpu_idx)
        bucket_logger.debug(
            f"GPU #{gpu_idx} gets {this_bucket[0]}, {this_bucket[1]}"
        )
        for embs in holder.partitioned_embeddings.values():
            assert embs.is_shared()
        self.gpu_pool.schedule(
            gpu_idx,
            SubprocessArgs(
                lhs_types=holder.lhs_partitioned_types,
                rhs_types=holder.rhs_partitioned_types,
                lhs_part=cur_b.lhs,
                rhs_part=cur_b.rhs,
                lhs_subpart=this_bucket[0],
                rhs_subpart=this_bucket[1],
                next_lhs_subpart=next_bucket[0]
                if next_bucket is not None
                else None,
                next_rhs_subpart=next_bucket[1]
                if next_bucket is not None
                else None,
                trainer=self.trainer,
                model=self.model,
                all_embs=holder.partitioned_embeddings,
                subpart_slices=subpart_slices,
                subbuckets=subbuckets,
                batch_size=config.batch_size,
                lr=config.lr,
            ),
        )
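
    # Main loop: seed every GPU with a job, then, as each job completes, record
    # its stats, release its sub-partition locks, advance that GPU's position in
    # its schedule, and try to schedule all GPUs again (a completed job may have
    # unlocked sub-partitions another GPU was waiting on). In total all
    # num_subparts * num_subparts sub-buckets get trained.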
    for gpu_idx in range(self.gpu_pool.num_gpus):
        schedule(gpu_idx)
    while busy_gpus:
        gpu_idx, result = self.gpu_pool.wait_for_next()
        assert gpu_idx == result.gpu_idx
        all_stats.append(result.stats)
        busy_gpus.remove(gpu_idx)
        this_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx]]
        next_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx] + 1]
        subparts = {
            (e, cur_b.lhs, this_bucket[0]) for e in holder.lhs_partitioned_types
        } | {(e, cur_b.rhs, this_bucket[1]) for e in holder.rhs_partitioned_types}
        for k in subparts:
            locked_parts.remove(k)
        index_in_schedule[gpu_idx] += 1
        if next_bucket is None:
            bucket_logger.debug(f"GPU #{gpu_idx} finished its schedule")
        for gpu_idx in range(config.num_gpus):
            schedule(gpu_idx)
    assert len(all_stats) == num_subparts * num_subparts

    time_spent_scheduling = tk.stop("scheduling")
    bucket_logger.debug(
        f"Time spent scheduling sub-buckets: {time_spent_scheduling:.4f} s"
    )
    bucket_logger.info(f"Speed: {num_edges / time_spent_scheduling:,.0f} edges/sec")
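
    # Undo the random sub-bucketing: shuffle the shared embeddings and the
    # per-entity optimizer state back with the inverse permutation so they
    # return to their original entity order.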
tk.start("rev_perm")
for (entity, part), embs in holder.partitioned_embeddings.items():
rev_perm = rev_perm_holder[entity, part]
optimizer = self.trainer.partitioned_optimizers[entity, part]
_C.shuffle(embs, rev_perm, os.cpu_count())
(state,) = optimizer.state.values()
_C.shuffle(state["sum"], rev_perm, os.cpu_count())
bucket_logger.debug(
f"Time spent mapping embeddings back from sub-buckets: {tk.stop('rev_perm'):.4f} s"
)
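
    # Write the held-out evaluation edges back into their original positions so
    # the bucket's edge storage is left exactly as it was found.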
    if eval_edge_idxs is not None:
        bucket_logger.debug("Restoring eval edges")
        tk.start("restore_eval")
        edges.lhs.tensor[eval_edge_idxs] = eval_edges_lhs
        edges.rhs.tensor[eval_edge_idxs] = eval_edges_rhs
        edges.rel[eval_edge_idxs] = eval_edges_rel
        bucket_logger.debug(
            f"Time spent restoring eval edges: {tk.stop('restore_eval'):.4f} s"
        )

    logger.debug(
        f"_coordinate_train: Time unaccounted for: {tk.unaccounted():.4f} s"
    )

    return Stats.sum(all_stats).average()