def warmup()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py [0:0]


    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Returns:
            The maximum number of tokens the model supports.
        """
        logger.debug("Warming up the model")
        start = time.time()
        # Just check that the warmup request parameters match the model capacity
        batch_size = self.engine.env.batch_size
        if len(batch.requests) > batch_size:
            raise ValueError(
                f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
            )

        # Counter-intuitively, the input batch is now ignored. Instead, dummy batches are created to cover
        # every supported prefill sequence length.
        seq_len = self.model.config.sequence_length
        if os.environ.get("SKIP_WARMUP", "0") == "1":
            logger.debug("Skipping warmup")
            return batch_size * seq_len
        bucket_seq_len = take_nearest_length(DEFAULT_PREFILL_BUCKETS, self.engine.max_prefill_length)
        decode_done = False
        for l in reversed(DEFAULT_PREFILL_BUCKETS):
            # Skip all the unsupported lengths
            if l > bucket_seq_len:
                continue
            # Create a dummy request of length l - 1 so that it gets padded up to the bucket length l.
            dummy_request = self._create_dummy_request(l - 1)
            # Request a few new tokens so that at least one is generated by prefill and another by decode.
            MAX_NEW_TOKENS = 10
            dummy_request.stopping_parameters.max_new_tokens = MAX_NEW_TOKENS
            warmup_batch = Batch(id=0,
                                 requests=[dummy_request],
                                 size=1,
                                 max_tokens=batch.max_tokens)
            logger.debug(f"Warmup for requests, len {l} seq_len {seq_len}")
            _generations, next_batch = self.prefill(warmup_batch)
            if next_batch is not None:
                self.decode([next_batch])
                decode_done = True
            self.clear()
        if not decode_done:
            logger.debug("No decode done during warmup")

        elapsed = time.time() - start
        logger.debug(f"Warmup done, took {elapsed:.2f}s")
        seq_len = self.engine.env.seq_len
        return batch_size * seq_len
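
The loop above first caps the prefill buckets at the engine's `max_prefill_length` through `take_nearest_length`, then builds one dummy request per remaining bucket so that each prefill graph gets exercised once. Below is a minimal sketch of that bucket lookup, assuming the helper returns the smallest bucket that fits the requested length (the bucket values are illustrative, not the library's actual defaults):

    import bisect

    # Prefill lengths are padded up to a fixed set of bucket sizes so that only a
    # small number of prefill graphs has to be compiled (illustrative values).
    DEFAULT_PREFILL_BUCKETS = [32, 64, 128, 256, 512, 1024, 2048, 4096]


    def take_nearest_length(buckets: list[int], length: int) -> int:
        """Return the smallest bucket that can hold `length` tokens.

        Assumption about the real helper: the last bucket is returned when
        `length` exceeds every bucket.
        """
        pos = bisect.bisect_left(buckets, length)
        return buckets[-1] if pos == len(buckets) else buckets[pos]


    # With max_prefill_length = 1000, warmup caps the buckets at 1024 and issues
    # one dummy prefill (and decode) per supported bucket: 1024, 512, ..., 32.
    print(take_nearest_length(DEFAULT_PREFILL_BUCKETS, 1000))  # 1024

Each dummy request is built with `l - 1` input tokens so that padding rounds it up to exactly the bucket length `l`.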
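
The value handed back to the router is the total token budget, i.e. batch size times sequence length, whether the graphs were actually compiled or the warmup was skipped with `SKIP_WARMUP=1`. A quick sketch of that arithmetic with made-up numbers (the real values come from `engine.env.batch_size` and `engine.env.seq_len`):

    # Hypothetical configuration, purely illustrative.
    batch_size = 8          # engine.env.batch_size
    sequence_length = 2048  # engine.env.seq_len

    max_supported_tokens = batch_size * sequence_length
    print(max_supported_tokens)  # 16384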