run_text_generation_tracking()

from optimum_benchmark/scenarios/energy_star/scenario.py


    def run_text_generation_tracking(self):
        """Measure text generation in two phases and fill in ``self.report``.

        The dataset is swept twice in identical batches: once running only the
        prefill step, once running full generation.  Decode-phase cost is then
        derived as (generate totals) minus (prefill totals).  Each metric is
        recorded only when the corresponding ``self.config`` flag is set.

        NOTE(review): decode memory is the max over the whole generate sweep,
        which includes the prefill work — confirm that is the intended reading.
        """
        self.logger.info("\t+ Running Text Generation tracking")

        batch_size = self.config.input_shapes["batch_size"]

        def prepared_batches():
            # Walk the dataset in batch-sized slices, preparing each for the backend.
            for start in tqdm(range(0, self.config.num_samples, batch_size)):
                yield self.backend.prepare_inputs(self.dataset[start : start + batch_size])

        # Prefill sweep — generation kwargs overridden (presumably to cap new
        # tokens so only the prompt pass is measured; TODO confirm overrides).
        prefill_kwargs = {**self.config.generate_kwargs, **TEXT_GENERATION_PREFILL_OVERRIDES}
        with self.track(task_name="prefill"):
            for batch in prepared_batches():
                self.backend.prefill(batch, prefill_kwargs)

        if self.config.energy:
            prefill_energy = self.energy_tracker.get_energy()
            self.report.prefill.energy = prefill_energy
            self.report.prefill.efficiency = Efficiency.from_energy(
                prefill_energy, self.dataset_prefill_volume, unit=PREFILL_EFFICIENCY_UNIT
            )

        if self.config.latency:
            prefill_latency = self.latency_tracker.get_latency()
            self.report.prefill.latency = prefill_latency
            self.report.prefill.throughput = Throughput.from_latency(
                prefill_latency, self.dataset_prefill_volume, unit=PREFILL_THROUGHPUT_UNIT
            )

        if self.config.memory:
            self.report.prefill.memory = self.memory_tracker.get_max_memory()

        # Full-generation sweep over the same batches, unmodified kwargs.
        with self.track(task_name="generate"):
            for batch in prepared_batches():
                self.backend.generate(batch, self.config.generate_kwargs)

        # Decode-phase metrics: generate totals minus the prefill baseline.
        if self.config.energy:
            decode_energy = self.energy_tracker.get_energy() - prefill_energy
            self.report.decode.energy = decode_energy
            self.report.decode.efficiency = Efficiency.from_energy(
                decode_energy, self.dataset_decode_volume, unit=DECODE_EFFICIENCY_UNIT
            )

        if self.config.latency:
            decode_latency = self.latency_tracker.get_latency() - prefill_latency
            self.report.decode.latency = decode_latency
            self.report.decode.throughput = Throughput.from_latency(
                decode_latency, self.dataset_decode_volume, unit=DECODE_THROUGHPUT_UNIT
            )

        if self.config.memory:
            self.report.decode.memory = self.memory_tracker.get_max_memory()