in server/text_generation_server/models/flash_causal_lm.py [0:0]
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
    """Return a batch restricted to the requests listed in ``request_ids``.

    Per-request Python lists and per-request tensors are re-indexed in the
    order given by ``request_ids``. The slots of the kept requests are
    gathered into a new ``slots`` tensor and ``cu_slots`` is rebuilt as the
    running sum of the kept requests' slot-range lengths.

    Args:
        request_ids: Ids of the requests to keep. Each id is looked up in
            ``self.requests_idx_mapping``; a failed lookup propagates.

    Returns:
        ``self`` unchanged when ``len(request_ids) == len(self)`` (the ids
        are then assumed to be the same set), otherwise a new instance of
        ``type(self)`` carrying the same ``batch_id``. When
        ``self.prefilling`` is true, ``position_ids``, ``slot_indices``,
        ``cache_lengths_tensor``, ``input_lengths_tensor`` and
        ``adapter_meta`` are left as ``None`` in the new batch.

    Raises:
        ValueError: If ``request_ids`` is empty.
    """
    if len(request_ids) == 0:
        raise ValueError("Batch must have at least one request")
    # We assume that if len(requests) == len(self) then the requests are the same
    if len(request_ids) == len(self):
        return self

    device = self.block_tables_tensor.device

    # Loop invariants: resolve them once so the allocation, the per-request
    # loop and the final slot gather all take the same code path, and so the
    # adapter table is not re-fetched for every request.
    use_triton = has_triton()
    adapter_to_index = get_adapter_to_index()

    # New values after filtering
    requests_idx_mapping = {}

    # Used to index into tensors
    indices = []

    if not use_triton:
        # Boolean mask over the current slots: True for slots to keep
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

    # Create on CPU to only move to GPU once instead of at every copy
    slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
    max_input_length = 0
    max_current_length = 0

    requests = []
    block_tables = []
    all_input_ids = []
    input_ids = []

    prompt_lengths = []
    input_lengths = []
    cache_lengths = []
    prefix_offsets = []
    read_offsets = []
    # Cumulative slot offsets of the filtered batch; entry i is where the
    # i-th kept request's slots start in the new `slots` tensor.
    cu_slots = [0]

    prefilling_mask = []
    prefill_logprob_tokens = []

    stopping_criterias = []
    top_n_tokens = []
    adapter_set = set()

    num_blocks = 0
    max_blocks = 0
    max_slots = 0
    cumulative_slot_tokens = 0

    for i, request_id in enumerate(request_ids):
        idx = self.requests_idx_mapping[request_id]
        indices.append(idx)
        requests_idx_mapping[request_id] = i

        requests.append(self.requests[idx])

        # Prefilling
        prefilling_mask.append(self.prefilling_mask[idx])

        # Get length
        request_input_length = self.input_lengths[idx]
        request_cache_length = self.cache_lengths[idx]
        max_input_length = max(max_input_length, request_input_length)
        max_current_length = max(
            max_current_length, request_cache_length + request_input_length
        )

        all_input_ids.append(self.all_input_ids[idx])

        prompt_lengths.append(self.prompt_lengths[idx])
        input_lengths.append(request_input_length)
        cache_lengths.append(request_cache_length)
        prefix_offsets.append(self.prefix_offsets[idx])
        read_offsets.append(self.read_offsets[idx])

        stopping_criterias.append(self.stopping_criterias[idx])

        top_n_tokens.append(self.top_n_tokens[idx])
        prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])

        # Adapter ids missing from the table fall back to index 0
        adapter_index = adapter_to_index.get(self.requests[idx].adapter_id, 0)
        adapter_set.add(adapter_index)

        request_block_table = self.block_tables[idx]
        num_blocks += len(request_block_table)
        block_tables.append(request_block_table)

        # Slot range of this request in the current (unfiltered) slots
        start_slot = self.cu_slots[idx]
        end_slot = self.cu_slots[idx + 1]
        slot_length = end_slot - start_slot

        if not use_triton:
            # Set slice
            slot_filtering_indices[start_slot:end_slot] = True

        cu_slots.append(cumulative_slot_tokens + slot_length)

        # Input ids if the request was part of a prefilling batch
        # If the batch was decoding we can index into the tensor directly later
        if self.prefilling:
            input_ids.append(self.input_ids[idx])
        else:
            # Copy to tensor (CPU): start of this request's slots in the
            # filtered batch, advanced by its cache length. Must run before
            # `cumulative_slot_tokens` is incremented below.
            slot_indices[i] = cumulative_slot_tokens + request_cache_length

        cumulative_slot_tokens += slot_length
        max_blocks = max(max_blocks, len(request_block_table))
        max_slots = max(max_slots, slot_length)

    all_input_ids_tensor = self.all_input_ids_tensor[indices]
    block_tables_tensor = self.block_tables_tensor[indices]
    next_token_chooser = self.next_token_chooser.filter(indices)
    top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
    speculative_ids = (
        self.speculative_ids[indices] if self.speculative_ids is not None else None
    )
    prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

    cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

    if not use_triton:
        # NOTE(review): a boolean mask returns the kept slots in their
        # original order, whereas `cu_slots` above follows `request_ids`
        # order. The two agree only when `request_ids` preserves the
        # batch's request order - confirm callers guarantee this.
        slots = self.slots[slot_filtering_indices]
    else:
        slots = self.slots.new_empty(cumulative_slot_tokens)
        gpu_cu_slots = cu_slots.to(device)
        # Start offset of each kept request's slots in the original `slots`
        slots_indexing_start = self.cu_slots.to(device)[indices]
        # Presumably copies each kept request's slot range from
        # `self.slots` into `slots` - see the `slots_filtering` helper.
        slots_filtering(
            max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
        )

    if self.prefilling:
        # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
        position_ids = None
        slot_indices = None
        cache_lengths_tensor = None
        input_lengths_tensor = None
        adapter_meta = None
    else:
        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        adapter_indices = self.adapter_meta.adapter_indices[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        cache_lengths_tensor = self.cache_lengths_tensor[indices]

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )
        adapter_meta = AdapterBatchMetadata(
            adapter_indices=adapter_indices,
            adapter_set=adapter_set,
            adapter_segments=adapter_segments,
            segment_indices=adapter_segment_indices,
        )

    # Prefill-only fields are reset to None on the filtered batch
    return type(self)(
        batch_id=self.batch_id,
        requests=requests,
        requests_idx_mapping=requests_idx_mapping,
        input_ids=input_ids,
        position_ids=position_ids,
        cu_seqlen_prefill=None,
        prefill_cache_indices=None,
        slot_indices=slot_indices,
        block_tables=block_tables,
        block_tables_tensor=block_tables_tensor,
        slots=slots,
        cu_slots=cu_slots,
        max_input_length=max_input_length,
        max_current_length=max_current_length,
        prefilling=self.prefilling,
        prefilling_mask=prefilling_mask,
        prefill_head_indices=None,
        prefill_next_token_indices=None,
        prefill_cu_outlens=None,
        prefill_logprob_tokens=prefill_logprob_tokens,
        prompt_lengths=prompt_lengths,
        prompt_lengths_tensor=prompt_lengths_tensor,
        input_lengths=input_lengths,
        input_lengths_tensor=input_lengths_tensor,
        cache_lengths=cache_lengths,
        cache_lengths_tensor=cache_lengths_tensor,
        prefix_offsets=prefix_offsets,
        read_offsets=read_offsets,
        all_input_ids=all_input_ids,
        all_input_ids_tensor=all_input_ids_tensor,
        next_token_chooser=next_token_chooser,
        stopping_criterias=stopping_criterias,
        top_n_tokens=top_n_tokens,
        top_n_tokens_tensor=top_n_tokens_tensor,
        num_blocks=num_blocks,
        max_blocks=max_blocks,
        speculative_ids=speculative_ids,
        adapter_meta=adapter_meta,
    )