run_benchmark.py:
import random
import time

import torch
from torch.profiler import profile, record_function, ProfilerActivity

from utils.benchmark_utils import annotate, create_parser
from utils.pipeline_utils import load_pipeline  # noqa: E402
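
# Local helpers (see utils/benchmark_utils.py and utils/pipeline_utils.py):
#   - create_parser: builds the CLI argument parser used under __main__ below
#   - load_pipeline: constructs (and, presumably, optimizes/compiles) the
#     diffusers Flux pipeline from the parsed arguments
#   - annotate: wraps a callable so it shows up as a named region in the
#     profiler trace (used in the tracing section below)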


def _determine_pipe_call_kwargs(args):
    # Defaults suit FLUX.1-schnell-style checkpoints (guidance disabled,
    # 256-token prompt limit); FLUX.1-dev instead uses guidance_scale=3.5
    # and a 512-token prompt limit.
    kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0}
    ckpt_id = args.ckpt
    if ckpt_id == "black-forest-labs/FLUX.1-dev":
        kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
    return kwargs


def set_rand_seeds(seed):
    random.seed(seed)
    torch.manual_seed(seed)


def main(args):
    set_rand_seeds(args.seed)
    pipeline = load_pipeline(args)
    set_rand_seeds(args.seed)

    # warmup
    for _ in range(3):
        image = pipeline(
            args.prompt,
            num_inference_steps=args.num_inference_steps,
            generator=torch.manual_seed(0),
            **_determine_pipe_call_kwargs(args)
        ).images[0]
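
    # The 3 warmup iterations above absorb one-time costs (e.g. torch.compile
    # graph capture and kernel autotuning, if load_pipeline enables them) so
    # they do not skew the timed runs below.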

    # run inference 10 times and compute mean / variance
    timings = []
    for _ in range(10):
        begin = time.time()
        image = pipeline(
            args.prompt,
            num_inference_steps=args.num_inference_steps,
            generator=torch.manual_seed(0),
            **_determine_pipe_call_kwargs(args)
        ).images[0]
        end = time.time()
        timings.append(end - begin)
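    # Note: the pipeline call returns decoded PIL images, so GPU work has
    # completed by the time `end` is read; timing a narrower, GPU-only region
    # would require an explicit torch.cuda.synchronize() before stopping the clock.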
    timings = torch.tensor(timings)
    print("time mean/var:", timings, timings.mean().item(), timings.var().item())
    image.save(args.output_file)

    # optionally generate a PyTorch Profiler trace;
    # this is done after benchmarking because tracing introduces overhead
    if args.trace_file is not None:
        # annotate parts of the model within the profiler trace
        pipeline.transformer.forward = annotate(pipeline.transformer.forward, "denoising_step")
        pipeline.vae.decode = annotate(pipeline.vae.decode, "decoding")
        pipeline.encode_prompt = annotate(pipeline.encode_prompt, "prompt_encoding")
        pipeline.image_processor.postprocess = annotate(
            pipeline.image_processor.postprocess, "postprocessing"
        )
        pipeline.image_processor.numpy_to_pil = annotate(
            pipeline.image_processor.numpy_to_pil, "pil_conversion"
        )
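        # Each patched method now records a named block (presumably via the
        # profiler's record_function mechanism inside `annotate`), so the
        # denoising, VAE decode, prompt encoding, and postprocessing stages
        # are individually visible in the exported trace.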
        # Generate trace with the PyTorch Profiler
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True
        ) as prof:
            with record_function("timed_region"):
                image = pipeline(
                    args.prompt,
                    num_inference_steps=args.num_inference_steps,
                    **_determine_pipe_call_kwargs(args)
                ).images[0]
        prof.export_chrome_trace(args.trace_file)
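        # The exported Chrome trace can be opened in chrome://tracing or
        # https://ui.perfetto.dev for inspection.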


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)
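
# Example invocation (flag names assumed to mirror the args.* attributes used
# above; check create_parser in utils/benchmark_utils.py for the actual CLI):
#   python run_benchmark.py \
#       --ckpt black-forest-labs/FLUX.1-dev \
#       --prompt "a photo of an astronaut riding a horse" \
#       --num_inference_steps 28 \
#       --seed 0 \
#       --output_file output.png \
#       --trace_file trace.json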