in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py
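# NOTE: The module prologue is omitted from this excerpt. The imports and
# constants below are a plausible reconstruction inferred from the calls in
# data_load_loop; the constant values in particular are hypothetical, and the
# real file also defines parquet_data_loader and the helpers used below.
import datetime
import os
import time

import jax
import jax.experimental.multihost_utils

import max_logging
import storage_utils
import train

# Assumed barrier message and GCS subdirectory names for each metrics CSV.
STEP_BARRIER_MSG = "Synchronize all hosts"
TOTAL_TRAINING_TIME_DIRECTORY = "total_training_time"
PER_STEP_DATA_LOADING_TIME_DIRECTORY = "per_step_data_loading_time"
PER_STEP_TIME_DIRECTORY = "per_step_time"
PER_EPOCH_TIME_DIRECTORY = "per_epoch_time"
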
def data_load_loop(config):
"""Main data loader loop.
Loads batches of data for each training step.
"""
    # The mesh must be set up for the distributed barrier to work, so we call
    # this setup function even though we use our own data iterator.
train.setup_mesh_and_model(config)
data_loader = parquet_data_loader(config)
max_steps = config.max_steps
    # Per-step data loading times, recorded as [host, epoch, step, seconds].
    per_step_data_loading_time = []
    # Per-step total times, which include data loading, simulated computation
    # time, and barrier wait time.
    step_time = []
    # Per-epoch total times.
    epoch_times = []
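    # Named collective barrier: block until every JAX process reaches this
    # point so all hosts enter the training loop together.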
jax.experimental.multihost_utils.sync_global_devices(
"Barrier before training steps start")
training_start = datetime.datetime.now()
max_logging.log(
f"STANDALONE DATALOADER : Started training loop on host {jax.process_index()}"
)
global_steps = 0
for i in range(config.epochs):
epoch_start = datetime.datetime.now()
local_steps = 0
step_data_loading_start = datetime.datetime.now()
step_start = datetime.datetime.now()
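        # Each iteration of this loop is one training step: time the batch
        # fetch, optionally simulate computation, then synchronize all hosts.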
for _ in data_loader:
step_data_loading_end = datetime.datetime.now()
data_loading_interval = (step_data_loading_end -
step_data_loading_start).total_seconds()
max_logging.log(
f"STANDALONE DATALOADER : Host {jax.process_index()} got a batch in {data_loading_interval} seconds on epoch {i} step {local_steps}"
)
per_step_data_loading_time.append(
[jax.process_index(), i, local_steps, data_loading_interval])
            if jax.process_index() == 0:
                # Host 0 simulates computation: if the batch arrived faster
                # than the configured per-step interval, sleep for the
                # remainder before releasing the barrier.
                if data_loading_interval < config.per_step_interval:
                    time.sleep(config.per_step_interval -
                               data_loading_interval)
                step_barrier_wait(STEP_BARRIER_MSG, local_steps)
            else:
                # The other hosts proceed straight to the barrier and wait
                # for host 0.
                step_barrier_wait(STEP_BARRIER_MSG, local_steps)
# Measure the per-step time.
step_end = datetime.datetime.now()
step_time.append([
jax.process_index(), i, local_steps,
(step_end - step_start).total_seconds()
])
step_start = step_end
            # Reset the data loading timer for the next step.
            step_data_loading_start = datetime.datetime.now()
local_steps += 1
global_steps += 1
if max_steps > 0 and global_steps >= max_steps:
max_logging.log(
f"STANDALONE DATALOADER : {global_steps} global steps reached. Stopped training."
)
break
        else:
            # for/else: executed only when the inner loop did NOT break,
            # i.e. the epoch ran to completion.
            measure_epoch_time(epoch=i,
                               epoch_start=epoch_start,
                               epoch_times=epoch_times)
            continue
        # max_steps was reached mid-epoch; still log the time spent in the
        # current, partial epoch before stopping.
        measure_epoch_time(epoch=i,
                           epoch_start=epoch_start,
                           epoch_times=epoch_times)
        break
training_end = datetime.datetime.now()
training_time = (training_end - training_start).total_seconds()
max_logging.log(
f"STANDALONE DATALOADER Metrics: Training completed in {training_time} seconds, on host {jax.process_index()}"
)
max_logging.log(
f"STANDALONE DATALOADER Metrics: Per-epoch total times on host {jax.process_index()}: {epoch_times}."
)
max_logging.log(
f"STANDALONE DATALOADER Metrics: Per-step data loading times on host {jax.process_index()}: {per_step_data_loading_time}."
)
    max_logging.log(
        f"STANDALONE DATALOADER Metrics: Per-step times on host {jax.process_index()}: {step_time}."
    )
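    # Optionally persist the collected metrics to GCS as CSVs, one file per
    # host, grouped by metric type under the run name.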
if config.gcs_metrics_bucket:
max_logging.log(
f"Uploading metrics to GCS bucket {config.gcs_metrics_bucket} on host {jax.process_index()}"
)
base_name = f"{jax.process_index()}.csv"
# Upload total training time.
storage_utils.upload_csv(
config.gcs_metrics_bucket,
os.path.join(config.run_name, TOTAL_TRAINING_TIME_DIRECTORY,
base_name), [[jax.process_index(), training_time]])
# Upload per-step data loading time.
storage_utils.upload_csv(
config.gcs_metrics_bucket,
os.path.join(config.run_name, PER_STEP_DATA_LOADING_TIME_DIRECTORY,
base_name), per_step_data_loading_time)
# Upload per-step total time.
storage_utils.upload_csv(
config.gcs_metrics_bucket,
os.path.join(config.run_name, PER_STEP_TIME_DIRECTORY, base_name),
step_time)
    # Upload per-epoch times.
storage_utils.upload_csv(
config.gcs_metrics_bucket,
os.path.join(config.run_name, PER_EPOCH_TIME_DIRECTORY, base_name),
epoch_times)
max_logging.log(
f"Finished uploading metrics to GCS bucket {config.gcs_metrics_bucket} on host {jax.process_index()}"
)
max_logging.log(f"Run name for accessing metrics: {config.run_name}")