def data_load_loop(config)

in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py [0:0]


def data_load_loop(config):
    """Main data loader loop.
  Loads batches of data for each training step.
  """
    # The mesh seems to need to be set up for the distributed barrier to work,
    # so we call this function for its setup side effects but use our own
    # iterator.
    train.setup_mesh_and_model(config)
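    # Build the Parquet-backed data loader used in place of the training
    # iterator.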
    data_loader = parquet_data_loader(config)

    max_steps = config.max_steps

    # Record per-step data loading times as [host, epoch, step, seconds] rows.
    per_step_data_loading_time = []

    # Record per-step times, which include data loading, simulated computation
    # time, and barrier wait time.
    step_time = []

    # Record per-epoch total time.
    epoch_times = []

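    # Synchronize all hosts so that the overall training timer starts from a
    # common point.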
    jax.experimental.multihost_utils.sync_global_devices(
        "Barrier before training steps start")
    training_start = datetime.datetime.now()
    max_logging.log(
        f"STANDALONE DATALOADER : Started training loop on host {jax.process_index()}"
    )
    global_steps = 0
    for i in range(config.epochs):
        epoch_start = datetime.datetime.now()
        local_steps = 0
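        # Both timers start before the first batch and are reset at the end of
        # each step.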
        step_data_loading_start = datetime.datetime.now()
        step_start = datetime.datetime.now()
        for _ in data_loader:
            step_data_loading_end = datetime.datetime.now()
            data_loading_interval = (step_data_loading_end -
                                     step_data_loading_start).total_seconds()
            max_logging.log(
                f"STANDALONE DATALOADER : Host {jax.process_index()} got a batch in {data_loading_interval} seconds on epoch {i} step {local_steps}"
            )
            per_step_data_loading_time.append(
                [jax.process_index(), i, local_steps, data_loading_interval])

            # Host 0 sleeps to simulate computation so that each step lasts at
            # least per_step_interval seconds; every host then waits at the
            # step barrier.
            if (jax.process_index() == 0
                    and data_loading_interval < config.per_step_interval):
                time.sleep(config.per_step_interval - data_loading_interval)
            step_barrier_wait(STEP_BARRIER_MSG, local_steps)

            # Measure the per-step time.
            step_end = datetime.datetime.now()
            step_time.append([
                jax.process_index(), i, local_steps,
                (step_end - step_start).total_seconds()
            ])
            step_start = step_end
            # Reset the data loading timer for the next step.
            step_data_loading_start = datetime.datetime.now()

            local_steps += 1
            global_steps += 1
            if max_steps > 0 and global_steps >= max_steps:
                max_logging.log(
                    f"STANDALONE DATALOADER : {global_steps} global steps reached. Stopped training."
                )
                break
        else:
            # Only executed if the inner loop did NOT break.
            measure_epoch_time(epoch=i,
                               epoch_start=epoch_start,
                               epoch_times=epoch_times)
            continue
        # Training stopped early at max_steps; we still want to log the time
        # spent in the current (partial) epoch.
        measure_epoch_time(epoch=i,
                           epoch_start=epoch_start,
                           epoch_times=epoch_times)
        break

    training_end = datetime.datetime.now()
    training_time = (training_end - training_start).total_seconds()

    max_logging.log(
        f"STANDALONE DATALOADER Metrics: Training completed in {training_time} seconds, on host {jax.process_index()}"
    )
    max_logging.log(
        f"STANDALONE DATALOADER Metrics: Per-epoch total times on host {jax.process_index()}: {epoch_times}."
    )
    max_logging.log(
        f"STANDALONE DATALOADER Metrics: Per-step data loading times on host {jax.process_index()}: {per_step_data_loading_time}."
    )
    max_logging.log(
        f"STANDALONE DATALOADER Metrics: Per-step times on host {jax.process_index()}: {step_time}.",
    )

    if config.gcs_metrics_bucket:
        max_logging.log(
            f"Uploading metrics to GCS bucket {config.gcs_metrics_bucket} on host {jax.process_index()}"
        )
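        # One metrics CSV per host, named by process index.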
        base_name = f"{jax.process_index()}.csv"
        # Upload total training time.
        storage_utils.upload_csv(
            config.gcs_metrics_bucket,
            os.path.join(config.run_name, TOTAL_TRAINING_TIME_DIRECTORY,
                         base_name), [[jax.process_index(), training_time]])

        # Upload per-step data loading time.
        storage_utils.upload_csv(
            config.gcs_metrics_bucket,
            os.path.join(config.run_name, PER_STEP_DATA_LOADING_TIME_DIRECTORY,
                         base_name), per_step_data_loading_time)

        # Upload per-step total time.
        storage_utils.upload_csv(
            config.gcs_metrics_bucket,
            os.path.join(config.run_name, PER_STEP_TIME_DIRECTORY, base_name),
            step_time)

        # Upload per-epoch total time.
        storage_utils.upload_csv(
            config.gcs_metrics_bucket,
            os.path.join(config.run_name, PER_EPOCH_TIME_DIRECTORY, base_name),
            epoch_times)

        max_logging.log(
            f"Finished uploading metrics to GCS bucket {config.gcs_metrics_bucket} on host {jax.process_index()}"
        )
        max_logging.log(f"Run name for accessing metrics: {config.run_name}")