ssiog/training.py (363 lines of code):
#!/usr/bin/env python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A Synthetic Scale IO Generator for training workloads.
"""
import argparse
import fsspec
import queue
import random
import sys
import traceback
import threading
import time
from typing import Iterable
import logging
import arguments
import util
import gcsfs
import torch.distributed as td
import pyarrow.fs as fs
import monitoring
from opentelemetry import metrics
import metrics_logger
# TODO(coryan) - the sample size and batch size should be randomly sampled.
# TODO(raj-prince) - write a development guide.
# TODO(raj-prince) - clean up the logging path.
# TODO(raj-prince) - overall testing at scale.
# TODO(raj-prince) - figure out how to write unit tests.
# Global for recording sample latency to export.
sample_lat = metrics.NoOpHistogram("no_op")
# Initialize the global metrics logger with a no-op logger.
sample_lat_logger = metrics_logger.NoOpMetricsLogger()
# Configure basic INFO-level logging for the global logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
logger = logging.getLogger(__name__)
class Source(object):
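    """A named filesystem and the objects available under a single prefix."""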
def __init__(self, name: str, filesystem: fs.FileSystem, objects: Iterable[str]):
self.name = name
self.filesystem = filesystem
self.objects = list(objects)
def setup_metrics_exporter(args):
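    """Create the OpenTelemetry histogram used to export sample latencies."""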
# Initialize the OpenTelemetry MeterProvider
meter = monitoring.initialize_monitoring_provider(exporter_type=args.exporter_type)
# Create a histogram metric
global sample_lat
sample_lat = meter.create_histogram(
name="ssiog.sample_lat",
description="Sample latency histogram",
unit="ms"
)
logger.info("Metrics exporter initialized.")
def setup_metrics_logger(args):
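    """Replace the no-op metrics logger with an async logger writing to args.metrics_file."""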
global sample_lat_logger
sample_lat_logger = metrics_logger.AsyncMetricsLogger(file_name=args.metrics_file)
logger.info("Metrics logger initialized.")
def setup_logger(args):
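    """Configure the application logger (named args.label): level, destination, and record format."""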
global logger
logger = logging.getLogger(args.label)
# No propagation in the logger hierarchy.
logger.propagate = False
# Log level.
log_level = getattr(logging, args.log_level)
logger.setLevel(log_level)
    # Log destination: a file if one was configured, otherwise the console.
    handler = logging.FileHandler(args.log_file) if args.log_file else logging.StreamHandler()
    # Consistent record formatting.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.info("Logger initialized.")
def close_metrics_logger():
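    """Close the metrics logger, flushing any buffered measurements."""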
global sample_lat_logger
sample_lat_logger.close()
def training():
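    """Entry point for the workload: parse arguments, set up telemetry, join the process group, and run every epoch."""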
# Parse arguments
args = arguments.parse_args()
# Initialize the global application logger.
logger.info("Setting up logger.")
setup_logger(args)
logger.debug(f"Running with args: {args}")
# Initialize the OpenTelemetry MeterProvider
if args.export_metrics:
logger.info("Setting up otlp metrics exporter.")
setup_metrics_exporter(args)
# Initialize the metrics logger.
if args.log_metrics:
logger.info(f"Logging metrics to: {args.metrics_file}")
setup_metrics_logger(args)
logger.info("Initial setup completed.\n")
logger.info(f"Starting process: {args.group_member_id}/{args.group_size}")
td.init_process_group(
"gloo",
init_method=f"tcp://{args.group_coordinator_address}:{args.group_coordinator_port}",
rank=args.group_member_id,
world_size=args.group_size,
)
logger.info(f"Process started successfully: {args.group_member_id}/{args.group_size}\n")
logger.info(f"Logging important workload configurations.")
logger.info(f"Total epochs: {args.epochs}")
logger.info(f"Sample size (bytes): {args.sample_size}")
logger.info(f"Batch size: {args.batch_size}")
logger.info(f"Steps: {args.steps}")
logger.info(f"Read order: {args.read_order[0]}")
logger.info(f"Background queue max size: {args.background_queue_maxsize}")
logger.info(f"Background threads: {args.background_threads}")
logger.info(f"Group member id: {args.group_member_id}")
logger.info(f"Group size: {args.group_size}")
logger.info(f"Label: {args.label}")
logger.info(f"Data set path: {args.prefix}.\n")
sources = configure_object_sources(args)
for epoch in range(args.epochs):
logger.info(f"******** Starting epoch: {epoch} ********.")
logger.info(f"Configure epoch: {epoch}.")
(reader, read_order, filesystem_name, filesystem, epoch_objects) = (configure_epoch(sources, args))
logger.info(f"Configured, total objects: {len(epoch_objects)}")
logger.info(f"Configuring samples.")
samples = configure_samples(epoch_objects, filesystem, args)
logger.info(f"Configured, total selected samples: {len(samples)}")
logger.info(f"Running epoch: {epoch}")
for summary in Epoch(reader, epoch_objects, filesystem, samples, args):
logger.info(f"Epoch: {epoch}, {summary}")
logger.info(f"Epoch {epoch} completed.\n")
# Clear the kernel cache
if args.clear_pagecache_after_epoch:
util.clear_kernel_cache(logger)
def Epoch(
reader: callable,
epoch_objects: Iterable[str],
    filesystem: fs.FileSystem,
samples: list,
args: argparse.Namespace,
):
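    """Run one epoch: fan reads out to background threads and yield a summary line per step."""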
q = queue.Queue(maxsize=args.background_queue_maxsize)
for i in range(args.background_threads):
threading.Thread(
daemon=True,
target=_background,
args=(
reader,
q,
epoch_objects,
i,
args.background_threads,
filesystem,
args.sample_size,
samples,
),
).start()
step_start = time.monotonic_ns()
step = 0
running = args.background_threads
batch_samples = 0
remaining = len(samples)
logger.debug("Starting the steps loop.")
while running != 0 and step < args.steps:
item = q.get()
if isinstance(item, Failed):
raise Exception("One of the background threads failed.")
if isinstance(item, Done):
q.task_done()
running -= 1
continue
q.task_done()
batch_samples += 1
        remaining -= 1
if batch_samples < args.batch_size:
continue
duration_ns = time.monotonic_ns() - step_start
yield f"Step: {step}, Duration (ms): {duration_ns/1000000}, Batch-sample: {batch_samples}"
if td.get_world_size() > 1:
td.barrier()
step_start = time.monotonic_ns()
step += 1
batch_samples = 0
for i in range(step, args.steps):
logger.info(f"Empty step {i}")
if td.get_world_size() > 1:
td.barrier()
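# Sentinel types placed on the work queue by the background reader threads.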
class Done(object):
pass
class Failed(object):
pass
def _subset(samples: Iterable, index: int, count: int) -> list[str]:
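    """Return every count-th element of samples, starting at position index."""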
return [o for i, o in enumerate(samples) if i % count == index]
def _background(
reader: callable,
    out_queue: queue.Queue,
object_names: Iterable[str],
thread_id: int,
thread_count: int,
filesystem: fs.FileSystem,
sample_size: int,
samples: list,
):
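    """Producer thread: push reader results onto the queue, signal Failed on error, and always finish with Done."""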
logger.debug(f"Background thread {thread_id} started.")
try:
success = True
        for r in reader(object_names, thread_id, thread_count, filesystem, sample_size, samples):
            out_queue.put(r)
    except Exception as e:
        success = False
        out_queue.put(Failed())
        logger.error(f"Background thread {thread_id} failed: {e}")
    finally:
        out_queue.put(Done())
if success:
logger.debug(f"Background thread {thread_id} completed.")
def sequential_reader(
object_names: Iterable[str],
thread_id: int,
thread_count: int,
filesystem: fs.FileSystem,
sample_size: int,
samples: list,
):
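    """Read this thread's share of the objects sequentially, one sample_size chunk at a time."""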
subset = _subset(object_names, td.get_rank(), td.get_world_size())
subset = _subset(subset, thread_id, thread_count)
for name in subset:
# Only read as many samples as have been configured for this object.
max_offset = sample_size * len([o for n, o in samples if n == name])
logger.debug(f"Reading {name} sequentially from {0} to {max_offset}.")
with filesystem.open_input_stream(name) as f:
offset = 0
while offset < max_offset:
start_time = time.monotonic_ns()
chunk = f.read(sample_size)
elapsed_time = time.monotonic_ns() - start_time
sample_lat_logger.log_metric(elapsed_time / 1000000)
sample_lat.record(elapsed_time / 1000000, {"reader": "sequential"})
if not chunk:
break
yield (name, offset, elapsed_time)
offset += len(chunk)
# TODO(raj-prince): discuss what observability is required for the file_random_reader pattern.
def file_random_reader(
object_names: Iterable[str],
thread_id: int,
thread_count: int,
filesystem: fs.FileSystem,
sample_size: int,
samples: list,
):
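    """Read each of this thread's objects fully into memory, then slice out its configured samples."""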
subset = _subset(object_names, td.get_rank(), td.get_world_size())
subset = _subset(subset, thread_id, thread_count)
for name in subset:
        with filesystem.open_input_file(name) as f:
            data = f.readall()
offsets = [o for n, o in samples if n == name]
for offset in offsets:
chunk = data[offset : min(len(data), offset + sample_size)]
yield (offset, chunk)
del offsets
del data
def full_random_reader(
object_names: Iterable[str],
thread_id: int,
thread_count: int,
filesystem: fs.FileSystem,
sample_size: int,
samples: list,
):
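    """Read this thread's share of (object, offset) samples using positional reads across all objects."""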
files = {n: filesystem.open_input_file(n) for n in object_names}
subset = _subset(samples, td.get_rank(), td.get_world_size())
subset = _subset(subset, thread_id, thread_count)
for name, offset in subset:
logger.debug(f"Reading {name} at {offset} with size {sample_size}.")
start_time = time.monotonic_ns()
try:
chunk = files[name].read_at(sample_size, offset)
except Exception as e:
logger.error(f"error in reading {name} at {offset} with size {sample_size}: {e}")
raise
elapsed_time = time.monotonic_ns() - start_time
sample_lat_logger.log_metric(elapsed_time / 1000000)
sample_lat.record(elapsed_time / 1000000, {"reader": "full_random"})
logger.debug(f"Complete reading {name} at {offset} with size {sample_size} in {elapsed_time / 1000000} ms.")
        if not chunk:
            logger.error(f"Empty chunk for {name} at offset {offset}.")
            raise ValueError(f"empty chunk for {name} at offset {offset}")
yield (offset, chunk)
    for f in files.values():
        f.close()
del files
del samples
def configure_samples(
object_names: Iterable[str], filesystem: fs.FileSystem, args: argparse.Namespace
):
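    """Enumerate (object, offset) samples, randomly select the required number, and broadcast them from rank 0."""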
samples = []
logger.info(f"Opening {len(object_names)} files.")
files = {n: filesystem.open_input_file(n) for n in object_names}
req_samples = args.batch_size * args.steps * args.group_size
logger.info(f"Collecting {req_samples} samples.")
for name, f in files.items():
samples.extend([(name, offset) for offset in range(0, f.size(), args.sample_size)])
logger.info(f"Total samples: {len(samples)}")
logger.info(f"Selecting {req_samples} samples from {len(samples)} randomly.")
if req_samples > len(samples):
logger.warning(f"Req sample ({req_samples}) > available ({len(samples)}), hence duplicated.")
sample_selection_stime = time.monotonic_ns()
samples = random.choices(samples, k=req_samples)
sample_selection_etime = time.monotonic_ns()
logger.info(f"Sample selection took {(sample_selection_etime - sample_selection_stime) / 1000000} ms.")
if td.get_world_size() > 1:
broadcast_time_start = time.monotonic_ns()
td.broadcast_object_list(samples, src=0)
td.barrier()
broadcast_time_end = time.monotonic_ns()
logger.info(f"Sample broadcast took {(broadcast_time_end - broadcast_time_start) / 1000000} ms.")
else:
logger.info("Broadcasting[samples] is not required as world size is 1.")
return samples
def configure_epoch(sources: dict[str, Source], args: argparse.Namespace):
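    """Choose the prefix, object subset, and read order for an epoch, broadcasting each choice from rank 0."""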
prefix = [random.choice(args.prefix)]
if td.get_world_size() > 1:
broadcast_time_start = time.monotonic_ns()
td.broadcast_object_list(prefix, src=0)
td.barrier()
broadcast_time_end = time.monotonic_ns()
logger.info(f"Prefix broadcast took {(broadcast_time_end - broadcast_time_start) / 1000000} ms.")
else:
logger.info("Broadcasting[prefix] is not required as world size is 1.")
p = prefix[0]
name = sources[p].name
filesystem = sources[p].filesystem
epoch_objects = sources[p].objects.copy()
random.shuffle(epoch_objects)
if len(epoch_objects) > args.object_count_limit:
epoch_objects = epoch_objects[0 : args.object_count_limit]
if td.get_world_size() > 1:
broadcast_time_start = time.monotonic_ns()
td.broadcast_object_list(epoch_objects, src=0)
td.barrier()
broadcast_time_end = time.monotonic_ns()
logger.info(f"Epoch-objects broadcast took {(broadcast_time_end - broadcast_time_start) / 1000000} ms.")
else:
logger.info("Broadcasting[epoch-objects] is not required as world size is 1.")
read_order = [random.choice(args.read_order)]
if td.get_world_size() > 1:
broadcast_time_start = time.monotonic_ns()
td.broadcast_object_list(read_order, src=0)
td.barrier()
broadcast_time_end = time.monotonic_ns()
logger.info(f"Read-order broadcast took {(broadcast_time_end - broadcast_time_start) / 1000000} ms.")
else:
logger.info("Broadcasting[read-order] is not required as world size is 1.")
if read_order[0] == "Sequential":
reader = sequential_reader
elif read_order[0] == "FileRandom":
reader = file_random_reader
elif read_order[0] == "FullRandom":
reader = full_random_reader
else:
raise Exception(f"Unknown reading order {read_order[0]}")
return (reader, read_order[0], name, filesystem, epoch_objects)
def configure_object_sources(args: argparse.Namespace) -> dict[str, Source]:
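    """Build a Source for each prefix: GCS via pyarrow, GCS via fsspec/gcsfs, or the local filesystem."""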
sources = dict()
for prefix in args.prefix:
if prefix.startswith("gs://"):
objects = fsspec.filesystem("gcs").ls(prefix.removeprefix("gs://"))
sources[prefix] = Source("gcs", fs.GcsFileSystem(), objects)
elif prefix.startswith("gcsfs://"):
sources[prefix] = Source(
"fsspec",
fs.PyFileSystem(fs.FSSpecHandler(gcsfs.GCSFileSystem())),
fsspec.filesystem("gcs").ls(prefix.removeprefix("gcsfs://")),
)
else:
sources[prefix] = Source(
"local", fs.LocalFileSystem(), fsspec.filesystem("local").ls(prefix)
)
return sources
def main():
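    """Run the workload, log any failure, and always flush metrics and tear down the process group."""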
logger.info("testing....logger")
try:
success = True
training()
except Exception as e:
success = False
logger.error(f"Workload failed with error: {e}")
traceback.print_exc()
sys.exit(1)
    finally:
        # Make sure to flush any metrics still in the buffer.
        close_metrics_logger()
        if td.is_initialized():
            td.destroy_process_group()
if success:
logger.info("Workload completed successfully.")
sys.exit(0)
if __name__ == "__main__":
main()