in text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
def warmup(self, batch: Batch) -> int:
    """Verify that the hardware can support the target load.

    Args:
        batch (`Batch`):
            A batch corresponding to the maximum number of concurrent requests.

    Return:
        The maximum number of tokens the model supports.
    """
    logger.debug("Warming up the model")
    start = time.time()
    # Just check that the warmup request parameters match the model capacity.
    batch_size = self.engine.env.batch_size
    if len(batch.requests) > batch_size:
        raise ValueError(
            f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
        )
    # Counter-intuitively, the input batch is ignored from here on: dummy batches are created instead to cover
    # all the supported prefill sequence lengths.
    seq_len = self.model.config.sequence_length
    if os.environ.get("SKIP_WARMUP", "0") == "1":
        logger.debug("Skipping warmup")
        return batch_size * seq_len
    bucket_seq_len = take_nearest_length(DEFAULT_PREFILL_BUCKETS, self.engine.max_prefill_length)
    decode_done = False
    for l in reversed(DEFAULT_PREFILL_BUCKETS):
        # Skip all the unsupported lengths.
        if l > bucket_seq_len:
            continue
        # Create a dummy request of length l - 1, so that it gets padded up to the bucket length l.
        dummy_request = self._create_dummy_request(l - 1)
        # Set max_new_tokens to a small value so that at least one token is generated by prefill and another
        # one by decode.
        MAX_NEW_TOKENS = 10
        dummy_request.stopping_parameters.max_new_tokens = MAX_NEW_TOKENS
        warmup_batch = Batch(id=0,
                             requests=[dummy_request],
                             size=1,
                             max_tokens=batch.max_tokens)
        logger.debug(f"Warmup for requests, len {l} seq_len {seq_len}")
        _generations, next_batch = self.prefill(warmup_batch)
        if next_batch is not None:
            self.decode([next_batch])
            decode_done = True
        self.clear()
    if not decode_done:
        logger.debug("No decode done during warmup")
    elapsed = time.time() - start
    logger.debug(f"Warmup done, took {elapsed:.2f}s")
    seq_len = self.engine.env.seq_len
    return batch_size * seq_len
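

# Illustrative sketch (not part of generator.py): the warmup loop above relies on `take_nearest_length`
# to map the engine's max prefill length onto one of the DEFAULT_PREFILL_BUCKETS, and each dummy request
# is built with length l - 1 so that it gets padded up to the bucket l. A minimal stand-in for that
# bucketing logic is shown below, assuming "nearest" means the smallest bucket that can hold the length;
# the names `_BUCKETS` and `_nearest_bucket` and the bucket values are hypothetical, not the library's API.

_BUCKETS = [16, 32, 64, 128, 256, 512, 1024]  # assumed stand-in for DEFAULT_PREFILL_BUCKETS


def _nearest_bucket(buckets: list[int], length: int) -> int:
    """Return the smallest bucket that can hold `length`, falling back to the largest bucket."""
    for bucket in sorted(buckets):
        if bucket >= length:
            return bucket
    return max(buckets)


if __name__ == "__main__":
    # A 100-token prompt would be padded up to the 128 bucket.
    assert _nearest_bucket(_BUCKETS, 100) == 128
    # A dummy request of length l - 1 (e.g. 127) still lands in the 128 bucket, as in the warmup loop.
    assert _nearest_bucket(_BUCKETS, 128 - 1) == 128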