run_benchmark.py (62 lines of code) (raw):

import random import time import torch from torch.profiler import profile, record_function, ProfilerActivity from utils.benchmark_utils import annotate, create_parser from utils.pipeline_utils import load_pipeline # noqa: E402 def _determine_pipe_call_kwargs(args): kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0} ckpt_id = args.ckpt if ckpt_id == "black-forest-labs/FLUX.1-dev": kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5} return kwargs def set_rand_seeds(seed): random.seed(seed) torch.manual_seed(seed) def main(args): set_rand_seeds(args.seed) pipeline = load_pipeline(args) set_rand_seeds(args.seed) # warmup for _ in range(3): image = pipeline( args.prompt, num_inference_steps=args.num_inference_steps, generator=torch.manual_seed(0), **_determine_pipe_call_kwargs(args) ).images[0] # run inference 10 times and compute mean / variance timings = [] for _ in range(10): begin = time.time() image = pipeline( args.prompt, num_inference_steps=args.num_inference_steps, generator=torch.manual_seed(0), **_determine_pipe_call_kwargs(args) ).images[0] end = time.time() timings.append(end - begin) timings = torch.tensor(timings) print('time mean/var:', timings, timings.mean().item(), timings.var().item()) image.save(args.output_file) # optionally generate PyTorch Profiler trace # this is done after benchmarking because tracing introduces overhead if args.trace_file is not None: # annotate parts of the model within the profiler trace pipeline.transformer.forward = annotate(pipeline.transformer.forward, "denoising_step") pipeline.vae.decode = annotate(pipeline.vae.decode, "decoding") pipeline.encode_prompt = annotate(pipeline.encode_prompt, "prompt_encoding") pipeline.image_processor.postprocess = annotate( pipeline.image_processor.postprocess, "postprocessing" ) pipeline.image_processor.numpy_to_pil = annotate( pipeline.image_processor.numpy_to_pil, "pil_conversion" ) # Generate trace with the PyTorch Profiler with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: with record_function("timed_region"): image = pipeline( args.prompt, num_inference_steps=args.num_inference_steps, **_determine_pipe_call_kwargs(args) ).images[0] prof.export_chrome_trace(args.trace_file) if __name__ == "__main__": parser = create_parser() args = parser.parse_args() main(args)