gen_image.py (20 lines of code) (raw):

import random import time import torch from torch.profiler import profile, record_function, ProfilerActivity from utils.benchmark_utils import annotate, create_parser from utils.pipeline_utils import load_pipeline # noqa: E402 def set_rand_seeds(seed): random.seed(seed) torch.manual_seed(seed) def main(args): pipeline = load_pipeline(args) set_rand_seeds(args.seed) image = pipeline( args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0 ).images[0] image.save(args.output_file) if __name__ == "__main__": parser = create_parser() args = parser.parse_args() main(args)