def train()

in torchbiggraph/train_cpu.py [0:0]


    def train(self) -> None:
        """Run the training loop over every (epoch, edge path, edge chunk) step.

        On the ``SINGLE_TRAINER`` rank, first replays any stats read from the
        checkpoint through ``self.stats_handler`` so that the already-completed
        part of the learning curve is reported again.

        Then, for each step yielded by ``self.iteration_manager``:

        * starts a new bucket-scheduler pass, bracketed by two barriers;
        * repeatedly acquires buckets from the scheduler until none remain.
          For each bucket it swaps partition embeddings, loads the bucket's
          edge chunk, and coordinates HOGWILD evaluation (before and after)
          and HOGWILD training on it;
        * releases the final bucket, waits on a barrier, and hands off to
          ``_maybe_write_checkpoint``.
        """
        holder = self.holder
        config = self.config
        iteration_manager = self.iteration_manager

        total_buckets = holder.nparts_lhs * holder.nparts_rhs

        # Yield stats from the checkpoint to reconstruct the saved part of
        # the learning curve. Only the SINGLE_TRAINER rank does this.
        if self.rank == SINGLE_TRAINER:
            for stats_dict in self.checkpoint_manager.maybe_read_stats():
                index: int = stats_dict["index"]
                stats: Optional[Stats] = self._optional_stats_from_dict(
                    stats_dict, "stats"
                )
                eval_stats_before: Optional[Stats] = self._optional_stats_from_dict(
                    stats_dict, "eval_stats_before"
                )
                eval_stats_after: Optional[Stats] = self._optional_stats_from_dict(
                    stats_dict, "eval_stats_after"
                )
                eval_stats_chunk_avg: Optional[Stats] = (
                    self._optional_stats_from_dict(stats_dict, "eval_stats_chunk_avg")
                )
                self.stats_handler.on_stats(
                    index,
                    eval_stats_before,
                    stats,
                    eval_stats_after,
                    eval_stats_chunk_avg,
                )

        for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
            logger.info(
                f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
                f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
                f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
            )
            edge_storage = EDGE_STORAGES.make_instance(iteration_manager.edge_path)
            logger.info(f"Edge path: {iteration_manager.edge_path}")

            # All trainers synchronize both before and after the bucket
            # scheduler starts a new pass.
            self._barrier()
            dist_logger.info("Lock client new epoch...")
            self.bucket_scheduler.new_pass(
                is_first=iteration_manager.iteration_idx == 0
            )
            self._barrier()

            remaining = total_buckets
            cur_b: Optional[Bucket] = None
            cur_stats: Optional[BucketStats] = None
            while remaining > 0:
                old_b: Optional[Bucket] = cur_b
                old_stats: Optional[BucketStats] = cur_stats
                cur_b, remaining = self.bucket_scheduler.acquire_bucket()
                logger.info(f"still in queue: {remaining}")
                if cur_b is None:
                    cur_stats = None
                    if old_b is not None:
                        # If we couldn't get a new pair, release the lock on
                        # the bucket we still hold to prevent a deadlock.
                        tic = time.perf_counter()
                        release_bytes = self._swap_partitioned_embeddings(
                            old_b, None, old_stats
                        )
                        release_time = time.perf_counter() - tic
                        release_rate = self._per_second(release_bytes, release_time)
                        logger.info(
                            f"Swapping old embeddings to release lock. io: {release_time:.2f} s for {release_bytes:,} bytes "
                            f"( {release_rate / 1e6:.2f} MB/sec )"
                        )
                    time.sleep(1)  # don't hammer td
                    continue

                tic = time.perf_counter()
                self.cur_b = cur_b
                bucket_logger = BucketLogger(logger, bucket=cur_b)
                self.bucket_logger = bucket_logger

                # Hand the previous bucket (with its stats) and the new bucket
                # to the swap routine, then bind the held embeddings for cur_b
                # onto the model.
                io_bytes = self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
                self.model.set_all_embeddings(holder, cur_b)

                # Global position of this bucket across all iterations.
                # NOTE(review): assumes `remaining` counts the buckets not yet
                # handed out *after* this acquisition — confirm against the
                # bucket scheduler.
                current_index = (
                    (iteration_manager.iteration_idx + 1) * total_buckets
                    - remaining
                    - 1
                )

                bucket_logger.debug("Loading edges")
                edges = edge_storage.load_chunk_of_edges(
                    cur_b.lhs,
                    cur_b.rhs,
                    edge_chunk_idx,
                    iteration_manager.num_edge_chunks,
                    shared=True,
                )
                num_edges = len(edges)

                # this might be off in the case of tensorlist or extra edge fields
                io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
                io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
                io_bytes += edges.rel.numel() * edges.rel.element_size()
                io_time = time.perf_counter() - tic
                tic = time.perf_counter()
                bucket_logger.debug("Shuffling edges")
                # Fix a seed to get the same permutation every time; have it
                # depend on all and only what affects the set of edges.

                # Note: for the sake of efficiency, we sample eval edge idxs
                # from the edge set *with replacement*, meaning that there may
                # be duplicates of the same edge in the eval set. When we swap
                # edges into the eval set, if there are duplicates then all
                # but one will be clobbered. These collisions are unlikely
                # if eval_fraction is small.
                #
                # Importantly, this eval sampling strategy is theoretically
                # sound:
                # * Training and eval sets are (exactly) disjoint
                # * Eval set may have (rare) duplicates, but they are
                #   uniformly sampled so it's still an unbiased estimator
                #   of the out-of-sample statistics
                num_eval_edges = int(num_edges * config.eval_fraction)
                num_train_edges = num_edges - num_eval_edges
                if num_eval_edges > 0:
                    g = torch.Generator()
                    # NOTE(review): presumably partition ids are ints. Tuple
                    # hashes of ints are not salted by PYTHONHASHSEED, so the
                    # seed is reproducible across processes. It can still
                    # differ between Python versions.
                    g.manual_seed(
                        hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs))
                    )
                    eval_edge_idxs = torch.randint(
                        num_edges, (num_eval_edges,), dtype=torch.long, generator=g
                    )
                else:
                    eval_edge_idxs = None

                # HOGWILD evaluation before training
                eval_stats_before = self._coordinate_eval(edges, eval_edge_idxs)
                if eval_stats_before is not None:
                    bucket_logger.info(f"Stats before training: {eval_stats_before}")
                eval_time = time.perf_counter() - tic
                tic = time.perf_counter()

                # HOGWILD training
                bucket_logger.debug("Waiting for workers to perform training")
                stats = self._coordinate_train(edges, eval_edge_idxs, epoch_idx)
                if stats is not None:
                    bucket_logger.info(f"Training stats: {stats}")
                train_time = time.perf_counter() - tic
                tic = time.perf_counter()

                # HOGWILD evaluation after training (same eval edges as before)
                eval_stats_after = self._coordinate_eval(edges, eval_edge_idxs)
                if eval_stats_after is not None:
                    bucket_logger.info(f"Stats after training: {eval_stats_after}")

                eval_time += time.perf_counter() - tic

                # Throughputs go through _per_second so that a zero duration
                # logs "nan" instead of raising ZeroDivisionError.
                train_rate = self._per_second(num_train_edges, train_time)
                eval_rate = self._per_second(2 * num_eval_edges, eval_time)
                io_rate = self._per_second(io_bytes, io_time)
                bucket_logger.info(
                    f"bucket {total_buckets - remaining} / {total_buckets} : "
                    f"Trained {num_train_edges} edges in {train_time:.2f} s "
                    f"( {train_rate / 1e6:.2g} M/sec ); "
                    f"Eval 2*{num_eval_edges} edges in {eval_time:.2f} s "
                    f"( {eval_rate / 1e6:.2g} M/sec ); "
                    f"io: {io_time:.2f} s for {io_bytes:,} bytes ( {io_rate / 1e6:.2f} MB/sec )"
                )

                self.model.clear_all_embeddings()

                cur_stats = BucketStats(
                    lhs_partition=cur_b.lhs,
                    rhs_partition=cur_b.rhs,
                    index=current_index,
                    train=stats,
                    eval_before=eval_stats_before,
                    eval_after=eval_stats_after,
                )

            # release the final bucket
            self._swap_partitioned_embeddings(cur_b, None, cur_stats)

            # Distributed Processing: all machines can leave the barrier now.
            self._barrier()

            current_index = (iteration_manager.iteration_idx + 1) * total_buckets - 1

            self._maybe_write_checkpoint(
                epoch_idx, edge_path_idx, edge_chunk_idx, current_index
            )

            # now we're sure that all partition files exist,
            # so be strict about loading them
            self.strict = True

    @staticmethod
    def _optional_stats_from_dict(stats_dict: dict, key: str) -> Optional[Stats]:
        """Deserialize ``stats_dict[key]`` with ``Stats.from_dict``.

        Returns None when ``key`` is not present in ``stats_dict``.
        """
        if key not in stats_dict:
            return None
        return Stats.from_dict(stats_dict[key])

    @staticmethod
    def _per_second(amount: float, seconds: float) -> float:
        """Return ``amount / seconds``, or NaN when ``seconds`` is not positive.

        Used only for throughput figures in log messages. A timer reading of
        zero (possible with a coarse clock or a trivially short phase) then
        shows up as "nan" instead of raising ZeroDivisionError mid-training.
        """
        if seconds <= 0:
            return float("nan")
        return amount / seconds