in text-generation-inference/server/text_generation_server/generator.py
def warmup(self, batch: Batch) -> int:
    """Verify if the hardware can support the target load.

    Args:
        batch (`Batch`):
            A batch corresponding to the maximum number of concurrent requests.

    Return:
        The maximum number of tokens the model supports.
    """
logger.debug("Warming up the model")
start = time.time()
# Just check that the warmup request parameters match the model capacity
# NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
if self.model.config.batch_size is not None:
batch_size = self.model.config.batch_size
else:
# batch size is not set, just assume it's unlimited and accept all requests
batch_size = len(batch.requests)
if len(batch.requests) > batch_size:
raise ValueError(
f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
)
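
    # NOTE (assumption, inferred from the error message above): the router sizes this warmup batch
    # from max-prefill-tokens / max-input-length, so a batch larger than the model's static batch
    # size points at a misconfigured launcher rather than a runtime failure.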
    # Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
    # batch sizes and sequence lengths.
    seq_len = self.model.config.sequence_length
    if os.environ.get("SKIP_WARMUP", "0") == "1":
        logger.debug("Skipping warmup")
        return batch_size * seq_len
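
    # Assumption: PREFILL_LENGTHS holds the bucketed prefill lengths the model supports, and
    # take_nearest_length snaps seq_len onto that grid; only buckets up to bucket_seq_len are
    # exercised below.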
    bucket_seq_len = take_nearest_length(seq_len)
    requests = [self._create_dummy_request(seq_len) for _ in range(batch_size)]
    for _ in reversed(range(batch_size)):
        # Prefill with different truncate sizes to test all prefill lengths. The list is reversed so the longest
        # sequences are tested first: if there is a memory failure, it will appear sooner.
        for l in reversed(PREFILL_LENGTHS):
            # Skip all the unsupported lengths
            if l > bucket_seq_len:
                continue
            # Set all truncate values for all requests
            for r in requests:
                r.truncate = l
                r.stopping_parameters.max_new_tokens = 10
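            # A small max_new_tokens is presumably enough here: the point is to exercise the
            # decode path for this shape, not to generate meaningful output.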
            warmup_batch = Batch(id=0,
                                 requests=requests,
                                 size=len(requests),
                                 max_tokens=batch.max_tokens)
            logger.debug(f"Warmup for {len(requests)} requests, truncate value {l} seq_len {seq_len}")
            _generations, next_batch = self.prefill(warmup_batch)
            if next_batch is not None:
                self.decode([next_batch])
            else:
                logger.debug(f"No decode on warmup for {len(requests)}x{l}")
            self.clear()
        # remove the last requests to decrease the batch size
        requests.pop()
    elapsed = time.time() - start
    logger.debug(f"Warmup done, took {elapsed:.2f}s")
    return batch_size * seq_len
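
# Note (illustrative, not part of the file above): because warmup() only checks the SKIP_WARMUP
# environment variable through os.environ, the whole pass can be bypassed when iterating on a
# small development instance, e.g.:
#
#   SKIP_WARMUP=1 text-generation-launcher --model-id <model-id>
#
# The server process inherits the variable from the launcher's environment; no other flag is
# involved.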