# text-generation-inference/server/text_generation_server/generator.py

def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
"""Decode the specified prefilled requests.
Args:
batches (`List[CachedBatch]`):
A list of previous batches containing the prefilled requests.
Return:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""
    # batches contains a list composed of:
    # - the batch id returned by the last decode,
    # - the batch id(s) returned by the last prefill(s)
    # Batches are always concatenated during prefill, so we can
    # just carry on with decoding. We adopt the id of the first
    # batch in the list as our next batch id.
    next_batch_id = batches[0].id
    request_ids = []
    for batch in batches:
        request_ids += batch.request_ids
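    # For example, with batches [CachedBatch(id=0, request_ids=[1, 2]), CachedBatch(id=3, request_ids=[5])],
    # next_batch_id is 0 and request_ids ends up as [1, 2, 5].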
    cleared_request_ids = []
    for slot in self.slots:
        if slot.state == Slot.State.READY and slot.request_id not in request_ids:
            cleared_request_ids.append(slot.request_id)
            slot.clear()
    if len(cleared_request_ids) > 0:
        logger.info(f"Clearing slots for requests {cleared_request_ids} as they are no longer requested.")
    active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
    if len(active_slots) < len(request_ids):
        logger.error("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
    # Reconstruct input_ids and attention_mask from slots
    input_ids = None
    attention_mask = None
    batch_size = len(self.slots)
    position_ids = torch.zeros(
        [batch_size, 1],
        dtype=torch.int64,
    )
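    # position_ids has one row per slot because a single token is decoded per sequence at each step;
    # rows belonging to empty slots simply stay at zero.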
    # Initialize pad_token_id, falling back to the EOS token when the tokenizer defines no padding token
    pad_token_id = self.tokenizer.pad_token_id
    if pad_token_id is None:
        if isinstance(self.tokenizer.eos_token_id, list):
            pad_token_id = self.tokenizer.eos_token_id[0]
        else:
            pad_token_id = self.tokenizer.eos_token_id
    # Create blank inputs covering all slots (even empty ones)
    input_ids = torch.full(
        [batch_size, 1],
        fill_value=pad_token_id,
        dtype=torch.int64,
    )
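    # Rows belonging to empty slots keep the pad token as a placeholder, so the batch always spans
    # all slots and tensor shapes stay constant across decoding steps.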
    cache_position = torch.zeros([1], dtype=torch.int64)
    for i, slot in enumerate(self.slots):
        if slot.state != Slot.State.EMPTY:
            # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
            input_ids.index_put_([torch.tensor([i])], slot.next_token)
            if not self._supports_static_cache:
                # When using a dynamic cache, the whole attention mask must be passed to the model at each iteration.
                if attention_mask is None:
                    # Create a default mask covering all slots (even empty ones)
                    attention_mask = torch.zeros(
                        [batch_size, slot.attention_mask.size(-1)],
                        dtype=torch.int64,
                    )
                attention_mask.index_put_([torch.tensor([i])], slot.attention_mask)
            position_ids.index_put_([torch.tensor([i])], torch.tensor(slot.position_id))
            cache_position = torch.maximum(cache_position, torch.tensor([slot.cache_position]))
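    # Note: index_put_([torch.tensor([i])], value) writes `value` into row i in place (the equivalent
    # of `tensor[i] = value`), and cache_position tracks the maximum cache offset across active slots.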
    if input_ids is None:
        raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
    extra_args = {}
    if self._supports_static_cache:
        extra_args["cache_position"] = position_ids.max().unsqueeze(0).to(self.model.device)
    else:
        extra_args["attention_mask"] = attention_mask.to(self.model.device)
        extra_args["past_key_values"] = self.past_key_values
    generations, next_batch = self._generate_token(
        next_batch_id,
        input_ids.to(self.model.device),
        self.model_one_token,
        position_ids=position_ids.to(self.model.device),
        **extra_args,
    )
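    # Advance each slot's position by the number of tokens it just generated.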
    for slot, gen in zip(self.slots, generations):
        slot.position_id += len(gen.tokens.ids)
    return generations, next_batch
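

# Minimal usage sketch (not part of the original file), assuming a `generator` object that exposes the
# usual prefill/decode pair, where `prefill` returns the generations plus a `CachedBatch` of pending
# requests. The helper name and the exact `prefill` signature are assumptions for illustration only.
def _drive_decode_loop(generator, batch):
    # Prefill the initial batch (assumed API), then keep decoding one token per step until the
    # returned CachedBatch no longer contains pending requests.
    generations, cached = generator.prefill(batch)
    while cached is not None and len(cached.request_ids) > 0:
        step_generations, cached = generator.decode([cached])
        generations += step_generations
    return generations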