in optimum_benchmark/scenarios/inference/scenario.py [0:0]
def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
    """Run the inference benchmark scenario against ``backend``.

    Dispatches on the backend's task — text generation, image diffusion,
    or plain inference — to (1) merge task-specific default kwargs,
    (2) build the report targets, (3) initialize the enabled trackers
    (latency / memory / energy), then generates inputs, tracks model
    loading, warms up, and runs each enabled tracking pass.

    NOTE(review): indentation in this chunk was flattened; the nesting
    below is reconstructed to match the apparent control flow — confirm
    against the original file.

    Args:
        backend: the backend (model + device) to benchmark.

    Returns:
        The populated ``BenchmarkReport`` built from the targets list.
    """
    self.backend = backend

    # Merge task-specific default kwargs; user-supplied values win
    # because self.config.* is spread last.
    if self.backend.config.task in TEXT_GENERATION_TASKS:
        self.logger.info("\t+ Updating Text Generation kwargs with default values")
        self.config.generate_kwargs = {**TEXT_GENERATION_DEFAULT_KWARGS, **self.config.generate_kwargs}
    elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
        self.logger.info("\t+ Updating Image Diffusion kwargs with default values")
        self.config.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs}

    # Build the list of report targets (sections) for this task.
    if self.backend.config.task in TEXT_GENERATION_TASKS:
        self.logger.info("\t+ Initializing Text Generation targets list")
        targets = ["load_model", "first_generate", "generate", "prefill", "decode"]
        # per-token latency is only reported for backends that support it
        if self.backend.config.name in PER_TOKEN_BACKENDS:
            targets.append("per_token")
    elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
        self.logger.info("\t+ Initializing Image Diffusion targets list")
        targets = ["load_model", "first_call", "call", "per_step"]
    else:
        self.logger.info("\t+ Initializing Inference targets list")
        targets = ["load_model", "first_forward", "forward"]

    self.report = BenchmarkReport.from_list(targets=targets)

    if self.config.latency:
        self.logger.info("\t+ Initializing Latency tracker")
        self.latency_tracker = LatencySessionTracker(
            device=self.backend.config.device, backend=self.backend.config.name
        )
        # NOTE(review): the per-token / per-step tracker setup is assumed
        # to be nested under the latency flag (the elif pairs with the
        # text-generation check, not with `if self.config.latency`).
        if self.backend.config.task in TEXT_GENERATION_TASKS and self.backend.config.name in PER_TOKEN_BACKENDS:
            self.logger.info("\t+ Initializing Per-Token Latency tracker")
            self.per_token_latency_tracker = PerTokenLatencySessionTrackerLogitsProcessor(
                device=self.backend.config.device, backend=self.backend.config.name
            )
            # Injected into generate() so each token's latency is recorded.
            self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.per_token_latency_tracker])
        elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
            self.logger.info("\t+ Initializing Diffusion Step Latency tracker")
            self.per_step_latency_tracker = PerStepLatencySessionTrackerPipelineCallback(
                device=self.backend.config.device, backend=self.backend.config.name
            )
            # Pipeline callback fires at the end of every diffusion step.
            self.config.call_kwargs["callback_on_step_end"] = self.per_step_latency_tracker

    if self.config.memory:
        self.logger.info("\t+ Initializing Memory tracker")
        self.memory_tracker = MemoryTracker(
            backend=self.backend.config.name,
            device=self.backend.config.device,
            device_ids=self.backend.config.device_ids,
        )

    if self.config.energy:
        self.logger.info("\t+ Initializing Energy tracker")
        self.energy_tracker = EnergyTracker(
            backend=self.backend.config.name,
            device=self.backend.config.device,
            device_ids=self.backend.config.device_ids,
        )

    # Generate raw inputs from the task/model shapes, then let the
    # backend adapt them to its expected format.
    self.logger.info(f"\t+ Generating inputs for task {self.backend.config.task}")
    self.inputs = InputGenerator(
        task=self.backend.config.task,
        model_shapes=self.backend.model_shapes,
        model_type=self.backend.config.model_type,
        input_shapes=self.config.input_shapes,
    )()

    # Model loading is tracked before inputs are prepared (loading does
    # not depend on the prepared inputs).
    self.run_model_loading_tracking()

    self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name}")
    self.inputs = self.backend.prepare_inputs(inputs=self.inputs)

    # Warmup (if requested), then one tracking pass per enabled metric,
    # each dispatched on the task type.
    if self.config.warmup_runs > 0:
        if self.backend.config.task in TEXT_GENERATION_TASKS:
            self.warmup_text_generation()
        elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
            self.warmup_image_diffusion()
        else:
            self.warmup_inference()

    if self.config.latency:
        if self.backend.config.task in TEXT_GENERATION_TASKS:
            if self.backend.config.name in PER_TOKEN_BACKENDS:
                self.run_per_token_text_generation_latency_tracking()
            else:
                self.run_text_generation_latency_tracking()
        elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
            self.run_image_diffusion_latency_tracking()
        else:
            self.run_inference_latency_tracking()

    if self.config.memory:
        if self.backend.config.task in TEXT_GENERATION_TASKS:
            self.run_text_generation_memory_tracking()
        elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
            self.run_image_diffusion_memory_tracking()
        else:
            self.run_inference_memory_tracking()

    if self.config.energy:
        if self.backend.config.task in TEXT_GENERATION_TASKS:
            self.run_text_generation_energy_tracking()
        elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
            self.run_image_diffusion_energy_tracking()
        else:
            self.run_inference_energy_tracking()

    return self.report