in torchbiggraph/train_cpu.py
def train(self) -> None:
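"""Run the full training loop: for each epoch, edge path, and edge chunk,
iterate over all (lhs, rhs) partition buckets, training and evaluating on
each bucket's edges, and possibly writing a checkpoint at the end of the chunk."""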
holder = self.holder
config = self.config
iteration_manager = self.iteration_manager
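# Each bucket is one (lhs partition, rhs partition) pair, so a full pass
# visits nparts_lhs * nparts_rhs buckets.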
total_buckets = holder.nparts_lhs * holder.nparts_rhs
# Yield stats from the checkpoint, to reconstruct the
# saved part of the learning curve.
if self.rank == SINGLE_TRAINER:
for stats_dict in self.checkpoint_manager.maybe_read_stats():
index: int = stats_dict["index"]
stats: Optional[Stats] = None
if "stats" in stats_dict:
stats = Stats.from_dict(stats_dict["stats"])
eval_stats_before: Optional[Stats] = None
if "eval_stats_before" in stats_dict:
eval_stats_before = Stats.from_dict(stats_dict["eval_stats_before"])
eval_stats_after: Optional[Stats] = None
if "eval_stats_after" in stats_dict:
eval_stats_after = Stats.from_dict(stats_dict["eval_stats_after"])
eval_stats_chunk_avg: Optional[Stats] = None
if "eval_stats_chunk_avg" in stats_dict:
eval_stats_chunk_avg = Stats.from_dict(
stats_dict["eval_stats_chunk_avg"]
)
self.stats_handler.on_stats(
index,
eval_stats_before,
stats,
eval_stats_after,
eval_stats_chunk_avg,
)
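# Main loop: one pass over the buckets for every (epoch, edge path, edge chunk)
# combination yielded by the iteration manager.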
for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
logger.info(
f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
)
edge_storage = EDGE_STORAGES.make_instance(iteration_manager.edge_path)
logger.info(f"Edge path: {iteration_manager.edge_path}")
self._barrier()
dist_logger.info("Lock client new epoch...")
self.bucket_scheduler.new_pass(
is_first=iteration_manager.iteration_idx == 0
)
self._barrier()
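# Keep acquiring buckets from the scheduler until every bucket of this pass
# has been processed.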
remaining = total_buckets
cur_b: Optional[Bucket] = None
cur_stats: Optional[BucketStats] = None
while remaining > 0:
old_b: Optional[Bucket] = cur_b
old_stats: Optional[BucketStats] = cur_stats
cur_b, remaining = self.bucket_scheduler.acquire_bucket()
logger.info(f"still in queue: {remaining}")
if cur_b is None:
cur_stats = None
if old_b is not None:
# if you couldn't get a new pair, release the lock
# to prevent a deadlock!
tic = time.perf_counter()
release_bytes = self._swap_partitioned_embeddings(
old_b, None, old_stats
)
release_time = time.perf_counter() - tic
logger.info(
f"Swapping old embeddings to release lock. io: {release_time:.2f} s for {release_bytes:,} bytes "
f"( {release_bytes / release_time / 1e6:.2f} MB/sec )"
)
time.sleep(1) # don't hammer td
continue
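# Got a new bucket: swap the old bucket's embeddings out and this bucket's in,
# timing the I/O.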
tic = time.perf_counter()
self.cur_b = cur_b
bucket_logger = BucketLogger(logger, bucket=cur_b)
self.bucket_logger = bucket_logger
io_bytes = self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
self.model.set_all_embeddings(holder, cur_b)
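# 0-based index of this bucket across all passes so far, used to tag its stats.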
current_index = (
(iteration_manager.iteration_idx + 1) * total_buckets
- remaining
- 1
)
bucket_logger.debug("Loading edges")
edges = edge_storage.load_chunk_of_edges(
cur_b.lhs,
cur_b.rhs,
edge_chunk_idx,
iteration_manager.num_edge_chunks,
shared=True,
)
num_edges = len(edges)
# this might be off in the case of tensorlist or extra edge fields
io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
io_bytes += edges.rel.numel() * edges.rel.element_size()
io_time = time.perf_counter() - tic
tic = time.perf_counter()
bucket_logger.debug("Shuffling edges")
# Fix a seed to get the same permutation every time; have it
# depend on all and only what affects the set of edges.
# Note: for the sake of efficiency, we sample eval edge idxs
# from the edge set *with replacement*, meaning that there may
# be duplicates of the same edge in the eval set. When we swap
# edges into the eval set, if there are duplicates then all
# but one will be clobbered. These collisions are unlikely
# if eval_fraction is small.
#
# Importantly, this eval sampling strategy is theoretically
# sound:
# * Training and eval sets are (exactly) disjoint
# * Eval set may have (rare) duplicates, but they are
# uniformly sampled so it's still an unbiased estimator
# of the out-of-sample statistics
num_eval_edges = int(num_edges * config.eval_fraction)
num_train_edges = num_edges - num_eval_edges
if num_eval_edges > 0:
g = torch.Generator()
g.manual_seed(
hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs))
)
eval_edge_idxs = torch.randint(
num_edges, (num_eval_edges,), dtype=torch.long, generator=g
)
else:
eval_edge_idxs = None
# HOGWILD evaluation before training
eval_stats_before = self._coordinate_eval(edges, eval_edge_idxs)
if eval_stats_before is not None:
bucket_logger.info(f"Stats before training: {eval_stats_before}")
eval_time = time.perf_counter() - tic
tic = time.perf_counter()
# HOGWILD training
bucket_logger.debug("Waiting for workers to perform training")
stats = self._coordinate_train(edges, eval_edge_idxs, epoch_idx)
if stats is not None:
bucket_logger.info(f"Training stats: {stats}")
train_time = time.perf_counter() - tic
tic = time.perf_counter()
# HOGWILD evaluation after training
eval_stats_after = self._coordinate_eval(edges, eval_edge_idxs)
if eval_stats_after is not None:
bucket_logger.info(f"Stats after training: {eval_stats_after}")
eval_time += time.perf_counter() - tic
bucket_logger.info(
f"bucket {total_buckets - remaining} / {total_buckets} : "
f"Trained {num_train_edges} edges in {train_time:.2f} s "
f"( {num_train_edges / train_time / 1e6:.2g} M/sec ); "
f"Eval 2*{num_eval_edges} edges in {eval_time:.2f} s "
f"( {2 * num_eval_edges / eval_time / 1e6:.2g} M/sec ); "
f"io: {io_time:.2f} s for {io_bytes:,} bytes ( {io_bytes / io_time / 1e6:.2f} MB/sec )"
)
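# Drop this bucket's embeddings from the model and record its stats;
# the embeddings themselves are written back when the bucket is swapped out.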
self.model.clear_all_embeddings()
cur_stats = BucketStats(
lhs_partition=cur_b.lhs,
rhs_partition=cur_b.rhs,
index=current_index,
train=stats,
eval_before=eval_stats_before,
eval_after=eval_stats_after,
)
# release the final bucket
self._swap_partitioned_embeddings(cur_b, None, cur_stats)
# Distributed Processing: all machines can leave the barrier now.
self._barrier()
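# All buckets of this pass are done: compute the index of its last bucket
# and, if needed, write a checkpoint.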
current_index = (iteration_manager.iteration_idx + 1) * total_buckets - 1
self._maybe_write_checkpoint(
epoch_idx, edge_path_idx, edge_chunk_idx, current_index
)
# now we're sure that all partition files exist,
# so be strict about loading them
self.strict = True