def _coordinate_train()

in torchbiggraph/train_gpu.py

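Coordinates training of a single bucket across multiple GPUs: it holds out any eval edges, randomly permutes entities to split the bucket into sub-buckets, dispatches sub-bucket pairs to the GPU worker pool (locking sub-partitions so no two GPUs use the same one concurrently), then maps the embeddings and optimizer state back to their original order, restores the eval edges, and returns the averaged training Stats.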

    def _coordinate_train(self, edges, eval_edge_idxs, epoch_idx) -> Stats:
        tk = TimeKeeper()

        config = self.config
        holder = self.holder
        cur_b = self.cur_b
        bucket_logger = self.bucket_logger
        num_edges = len(edges)
        if cur_b.lhs == cur_b.rhs and config.num_gpus > 1:
            num_subparts = 2 * config.num_gpus
        else:
            num_subparts = config.num_gpus

        edges_lhs = edges.lhs.tensor
        edges_rhs = edges.rhs.tensor
        edges_rel = edges.rel
        eval_edges_lhs = None
        eval_edges_rhs = None
        eval_edges_rel = None
        assert edges.weight is None, "Edge weights not implemented in GPU mode yet"
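        # Hold out the eval edges: copy them aside, overwrite their slots with the
        # last `num_eval_edges` edges, and truncate the tail so that training only
        # sees the remaining (non-eval) edges.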
        if eval_edge_idxs is not None:
            bucket_logger.debug("Removing eval edges")
            tk.start("remove_eval")
            num_eval_edges = len(eval_edge_idxs)
            eval_edges_lhs = edges_lhs[eval_edge_idxs]
            eval_edges_rhs = edges_rhs[eval_edge_idxs]
            eval_edges_rel = edges_rel[eval_edge_idxs]
            edges_lhs[eval_edge_idxs] = edges_lhs[-num_eval_edges:].clone()
            edges_rhs[eval_edge_idxs] = edges_rhs[-num_eval_edges:].clone()
            edges_rel[eval_edge_idxs] = edges_rel[-num_eval_edges:].clone()
            edges_lhs = edges_lhs[:-num_eval_edges]
            edges_rhs = edges_rhs[:-num_eval_edges]
            edges_rel = edges_rel[:-num_eval_edges]
            bucket_logger.debug(
                f"Time spent removing eval edges: {tk.stop('remove_eval'):.4f} s"
            )

        bucket_logger.debug("Splitting edges into sub-buckets")
        tk.start("mapping_edges")
        # randomly permute the entities, to get a random subbucketing
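        # The embeddings and the optimizer state are shuffled in place (via the C++
        # helpers) under that permutation, and the reverse permutation is kept so
        # everything can be mapped back to the original order after training.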
        perm_holder = {}
        rev_perm_holder = {}
        for (entity, part), embs in holder.partitioned_embeddings.items():
            perm = _C.randperm(self.entity_counts[entity][part], os.cpu_count())
            _C.shuffle(embs, perm, os.cpu_count())
            optimizer = self.trainer.partitioned_optimizers[entity, part]
            (optimizer_state,) = optimizer.state.values()
            _C.shuffle(optimizer_state["sum"], perm, os.cpu_count())
            perm_holder[entity, part] = perm
            rev_perm = _C.reverse_permutation(perm, os.cpu_count())
            rev_perm_holder[entity, part] = rev_perm

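        # For each (entity, partition), carve the permuted entity range into
        # num_subparts contiguous, near-equal slices, one per sub-partition.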
        subpart_slices: Dict[Tuple[EntityName, Partition, SubPartition], slice] = {}
        for entity_name, part in holder.partitioned_embeddings.keys():
            num_entities = self.entity_counts[entity_name][part]
            for subpart, subpart_slice in enumerate(
                split_almost_equally(num_entities, num_parts=num_subparts)
            ):
                subpart_slices[entity_name, part, subpart] = subpart_slice

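        # Split the edge list into num_subparts x num_subparts sub-buckets according
        # to which sub-partition the (permuted) lhs and rhs entity of each edge
        # falls into.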
        subbuckets = _C.sub_bucket(
            edges_lhs,
            edges_rhs,
            edges_rel,
            [self.entity_counts[r.lhs][cur_b.lhs] for r in config.relations],
            [perm_holder[r.lhs, cur_b.lhs] for r in config.relations],
            [self.entity_counts[r.rhs][cur_b.rhs] for r in config.relations],
            [perm_holder[r.rhs, cur_b.rhs] for r in config.relations],
            self.shared_lhs,
            self.shared_rhs,
            self.shared_rel,
            num_subparts,
            num_subparts,
            os.cpu_count(),
            config.dynamic_relations,
        )
        bucket_logger.debug(
            "Time spent splitting edges into sub-buckets: "
            f"{tk.stop('mapping_edges'):.4f} s"
        )
        bucket_logger.debug("Done splitting edges into sub-buckets")
        bucket_logger.debug(f"{subpart_slices}")

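        # Build a per-GPU schedule of (lhs, rhs) sub-bucket pairs; locked_parts below
        # ensures that no two GPUs ever work on the same sub-partition concurrently.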
        tk.start("scheduling")
        busy_gpus: Set[int] = set()
        all_stats: List[Stats] = []
        if cur_b.lhs != cur_b.rhs:  # the bucket is bipartite: lhs and rhs partitions differ
            gpu_schedules = build_bipartite_schedule(num_subparts)
        else:
            gpu_schedules = build_nonbipartite_schedule(num_subparts)
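        # Pad each schedule with two trailing None entries so that the current and
        # next-bucket lookups below never index past the end of the list.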
        for s in gpu_schedules:
            s.append(None)
            s.append(None)
        index_in_schedule = [0 for _ in range(self.gpu_pool.num_gpus)]
        locked_parts = set()

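        # Try to dispatch the next sub-bucket of this GPU's schedule: bail out if the
        # GPU is busy, its schedule is exhausted, or any required sub-partition is
        # locked by another GPU; otherwise lock the sub-partitions and enqueue the job.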
        def schedule(gpu_idx: GPURank) -> None:
            if gpu_idx in busy_gpus:
                return
            this_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx]]
            next_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx] + 1]
            if this_bucket is None:
                return
            subparts = {
                (e, cur_b.lhs, this_bucket[0]) for e in holder.lhs_partitioned_types
            } | {(e, cur_b.rhs, this_bucket[1]) for e in holder.rhs_partitioned_types}
            if any(k in locked_parts for k in subparts):
                return
            for k in subparts:
                locked_parts.add(k)
            busy_gpus.add(gpu_idx)
            bucket_logger.debug(
                f"GPU #{gpu_idx} gets {this_bucket[0]}, {this_bucket[1]}"
            )
            for embs in holder.partitioned_embeddings.values():
                assert embs.is_shared()
            self.gpu_pool.schedule(
                gpu_idx,
                SubprocessArgs(
                    lhs_types=holder.lhs_partitioned_types,
                    rhs_types=holder.rhs_partitioned_types,
                    lhs_part=cur_b.lhs,
                    rhs_part=cur_b.rhs,
                    lhs_subpart=this_bucket[0],
                    rhs_subpart=this_bucket[1],
                    next_lhs_subpart=next_bucket[0]
                    if next_bucket is not None
                    else None,
                    next_rhs_subpart=next_bucket[1]
                    if next_bucket is not None
                    else None,
                    trainer=self.trainer,
                    model=self.model,
                    all_embs=holder.partitioned_embeddings,
                    subpart_slices=subpart_slices,
                    subbuckets=subbuckets,
                    batch_size=config.batch_size,
                    lr=config.lr,
                ),
            )

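        # Kick off an initial sub-bucket on every GPU.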
        for gpu_idx in range(self.gpu_pool.num_gpus):
            schedule(gpu_idx)

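        # As each GPU finishes a sub-bucket: collect its stats, release its locks,
        # advance its schedule, and try to schedule more work on all GPUs.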
        while busy_gpus:
            gpu_idx, result = self.gpu_pool.wait_for_next()
            assert gpu_idx == result.gpu_idx
            all_stats.append(result.stats)
            busy_gpus.remove(gpu_idx)
            this_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx]]
            next_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx] + 1]
            subparts = {
                (e, cur_b.lhs, this_bucket[0]) for e in holder.lhs_partitioned_types
            } | {(e, cur_b.rhs, this_bucket[1]) for e in holder.rhs_partitioned_types}
            for k in subparts:
                locked_parts.remove(k)
            index_in_schedule[gpu_idx] += 1
            if next_bucket is None:
                bucket_logger.debug(f"GPU #{gpu_idx} finished its schedule")
            for gpu_idx in range(config.num_gpus):
                schedule(gpu_idx)

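        # Every (lhs, rhs) sub-bucket pair must have been trained on exactly once.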
        assert len(all_stats) == num_subparts * num_subparts
        time_spent_scheduling = tk.stop("scheduling")
        bucket_logger.debug(
            f"Time spent scheduling sub-buckets: {time_spent_scheduling:.4f} s"
        )
        bucket_logger.info(f"Speed: {num_edges / time_spent_scheduling:,.0f} edges/sec")

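        # Undo the random permutation so embeddings and optimizer state are back in
        # their original entity order.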
        tk.start("rev_perm")

        for (entity, part), embs in holder.partitioned_embeddings.items():
            rev_perm = rev_perm_holder[entity, part]
            optimizer = self.trainer.partitioned_optimizers[entity, part]
            _C.shuffle(embs, rev_perm, os.cpu_count())
            (state,) = optimizer.state.values()
            _C.shuffle(state["sum"], rev_perm, os.cpu_count())

        bucket_logger.debug(
            f"Time spent mapping embeddings back from sub-buckets: {tk.stop('rev_perm'):.4f} s"
        )

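        # Write the held-out eval edges back so the caller sees the full edge list.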
        if eval_edge_idxs is not None:
            bucket_logger.debug("Restoring eval edges")
            tk.start("restore_eval")
            edges.lhs.tensor[eval_edge_idxs] = eval_edges_lhs
            edges.rhs.tensor[eval_edge_idxs] = eval_edges_rhs
            edges.rel[eval_edge_idxs] = eval_edges_rel
            bucket_logger.debug(
                f"Time spent restoring eval edges: {tk.stop('restore_eval'):.4f} s"
            )

        logger.debug(
            f"_coordinate_train: Time unaccounted for: {tk.unaccounted():.4f} s"
        )

        return Stats.sum(all_stats).average()
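
For intuition, here is a minimal, self-contained sketch of the eval-edge trick used
above (toy tensors only; num_eval, saved_lhs and train_lhs are illustrative names,
not from the real code): held-out edges are copied aside, their slots are overwritten
with the last few edges, the tail is dropped for training, and the originals are
written back afterwards.

    import torch

    # Toy stand-in for one of the edge tensors (e.g. edges_lhs): 10 edges,
    # two of which (indices 1 and 4) are held out for evaluation.
    edges_lhs = torch.arange(10)
    eval_edge_idxs = torch.tensor([1, 4])
    num_eval = len(eval_edge_idxs)

    # Copy the eval edges aside (advanced indexing already returns a copy).
    saved_lhs = edges_lhs[eval_edge_idxs]

    # Overwrite the eval slots with the last `num_eval` edges, then drop the tail:
    # training now sees 8 edges and none of them are eval edges.
    edges_lhs[eval_edge_idxs] = edges_lhs[-num_eval:].clone()
    train_lhs = edges_lhs[:-num_eval]

    # ... train on train_lhs ...

    # Restore the eval edges in place so the full, original edge list is intact.
    edges_lhs[eval_edge_idxs] = saved_lhs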