in server/text_generation_server/models/idefics_causal_lm.py [0:0]
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
    """Shrink this batch in place so that it only holds ``request_ids``.

    Used to delete requests from the batch, for instance when a client lost
    its connection. Per-request Python lists are rebuilt, and the batched
    tensors (input ids, position ids, attention masks, pixel values, image
    hidden states, past key/values) are re-indexed and trimmed to the new
    maximum input length and the new right padding.

    Args:
        request_ids: Ids of the requests to keep. Rows of the filtered batch
            follow the order of this list.

    Returns:
        ``self``. When ``request_ids`` has as many entries as the batch, the
        batch is returned untouched (no re-indexing, no reordering).

    Raises:
        ValueError: If ``request_ids`` is empty.
        KeyError: If an id is missing from ``self.requests_idx_mapping``
            (assuming that mapping is a plain ``dict``).
    """
    if not request_ids:
        raise ValueError("Batch must have at least one request")
    if len(request_ids) == len(self):
        return self

    # Row of every kept request in the *current* batch tensors.
    keep_indices = [
        self.requests_idx_mapping[request_id] for request_id in request_ids
    ]
    # Kept requests are renumbered following the order of `request_ids`.
    requests_idx_mapping = {
        request_id: i for i, request_id in enumerate(request_ids)
    }

    # New per-request values after filtering.
    requests = [self.requests[idx] for idx in keep_indices]
    input_lengths = [self.input_lengths[idx] for idx in keep_indices]
    prefix_offsets = [self.prefix_offsets[idx] for idx in keep_indices]
    read_offsets = [self.read_offsets[idx] for idx in keep_indices]
    all_input_ids = [self.all_input_ids[idx] for idx in keep_indices]
    next_token_choosers = [self.next_token_choosers[idx] for idx in keep_indices]
    stopping_criterias = [self.stopping_criterias[idx] for idx in keep_indices]

    # Tokens each kept request may still generate.
    remaining_decode_tokens = [
        criteria.max_new_tokens - criteria.current_tokens
        for criteria in stopping_criterias
    ]
    total_remaining_decode_tokens = sum(remaining_decode_tokens)
    # Both maxima are floored at 0.
    max_input_length = max(0, *input_lengths)
    new_padding_right_offset = max(0, *remaining_decode_tokens)

    # Apply indices to input_ids, attention mask, past key values and other
    # items that need to be cached.
    input_ids = self.input_ids[keep_indices]
    position_ids = self.position_ids[keep_indices]
    # Column window: the last `max_input_length` columns located before the
    # old right padding, followed by `new_padding_right_offset` columns.
    # `self.padding_right_offset` is still the OLD offset here; it is only
    # overwritten at the end of the method.
    self.attention_mask = self.attention_mask[
        keep_indices,
        -(self.padding_right_offset + max_input_length) : (
            self.attention_mask.shape[1] - self.padding_right_offset
        )
        + new_padding_right_offset,
    ]
    # Do the same for pixel_values and image_attention_mask (same column
    # window, trailing axis kept whole).
    pixel_values = self.pixel_values[keep_indices]
    self.image_attention_mask = self.image_attention_mask[
        keep_indices,
        -(self.padding_right_offset + max_input_length) : (
            self.image_attention_mask.shape[1] - self.padding_right_offset
        )
        + new_padding_right_offset,
        :,
    ]
    if self.image_hidden_states is None:
        image_hidden_states = None
    else:
        image_hidden_states = self.image_hidden_states[keep_indices]

    # Ensure that past_key_values layers can be updated in-place.
    # `isinstance` (rather than an exact type check) also converts tuple
    # subclasses such as namedtuples, which do not support item assignment.
    if isinstance(self.past_key_values[0], tuple):
        self.past_key_values = [list(layer) for layer in self.past_key_values]

    # Update tensors layer by layer to allow incremental garbage collection.
    # NOTE(review): if max_input_length were 1, `-0:` would keep the whole
    # past instead of nothing; presumably unreachable because filtering
    # happens after at least one forward pass - confirm.
    past_kv_length = max_input_length - 1
    for layer in self.past_key_values:
        past_keys, past_values = layer
        if len(past_keys.shape) == 3:
            # Force past to be of dim [self_size, num_heads, ...] for easy
            # indexing. `len(self)` must still be the pre-filter batch size.
            past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
            past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
        if self.keys_head_dim_last:
            layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
        else:
            # Keys are trimmed along their last axis in this layout.
            layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
        del past_keys
        layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
        del past_values

    max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

    # Rebind the batch state only now: the code above relies on the old
    # `padding_right_offset` and the old `len(self)`.
    self.requests = requests
    self.requests_idx_mapping = requests_idx_mapping
    self.input_ids = input_ids
    self.pixel_values = pixel_values
    self.image_hidden_states = image_hidden_states
    self.position_ids = position_ids
    self.all_input_ids = all_input_ids
    self.input_lengths = input_lengths
    self.prefix_offsets = prefix_offsets
    self.read_offsets = read_offsets
    self.next_token_choosers = next_token_choosers
    self.stopping_criterias = stopping_criterias
    self.max_input_length = max_input_length
    self.padding_right_offset = new_padding_right_offset
    self.max_tokens = max_tokens

    return self