in core/src/queue.rs [97:211]
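/// Background blocking task that owns the request queue: it appends incoming
/// entries and greedily assembles the next batch on demand.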
fn queue_blocking_task(
    padded_model: bool,
    max_batch_tokens: usize,
    max_batch_requests: Option<usize>,
    max_concurrent_requests: usize,
    mut queue_receiver: mpsc::Receiver<QueueCommand>,
) {
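    // Upper bound on the number of requests in a single batch: the explicit
    // `max_batch_requests` limit if set, otherwise the server-wide concurrency limit.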
    let capacity = max_batch_requests.unwrap_or(max_concurrent_requests);
    let mut entries: VecDeque<Entry> = VecDeque::with_capacity(max_concurrent_requests);
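    // Serve queue commands until every sender handle has been dropped,
    // at which point `blocking_recv` returns `None` and the task exits.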
    while let Some(cmd) = queue_receiver.blocking_recv() {
        match cmd {
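            // A new entry is being added to the queue.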
            QueueCommand::Append(entry, span) => {
                let _span = span.entered();
                entries.push_back(*entry);
                let gauge = metrics::gauge!("te_queue_size");
                gauge.increment(1.0);
            }
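            // The batching task requested the next batch of queued entries.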
            QueueCommand::NextBatch {
                response_sender,
                span,
            } => {
                let _span = span.entered();
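                // Flattened (concatenated) token buffers for the whole batch, plus the
                // indices of requests that want pooled outputs vs. raw outputs.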
                let mut input_ids = Vec::with_capacity(max_batch_tokens);
                let mut token_type_ids = Vec::with_capacity(max_batch_tokens);
                let mut position_ids = Vec::with_capacity(max_batch_tokens);
                let mut pooled_indices = Vec::with_capacity(capacity);
                let mut raw_indices = Vec::with_capacity(capacity);
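                // Per-request metadata and cumulative sequence lengths: entry `i` occupies
                // `cu_seq_lengths[i]..cu_seq_lengths[i + 1]` in the flattened buffers.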
                let mut metadata = Vec::with_capacity(capacity);
                let mut cu_seq_lengths = Vec::with_capacity(capacity);
                cu_seq_lengths.push(0);
                let mut current_tokens = 0;
                let mut max_length = 0;
                let mut entry_index = 0;
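                // Greedily drain the queue front-to-back until the token or request budget is reached.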
                while let Some(entry) = entries.pop_front() {
                    // Skip entries whose response receiver was dropped, i.e. the client
                    // already cancelled or abandoned the request.
                    if entry.metadata.response_tx.is_closed() {
                        let counter = metrics::counter!("te_request_failure", "err" => "dropped");
                        counter.increment(1);
                        continue;
                    }
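                    // Estimate the batch cost with this entry included: padded models pay for a
                    // rectangular [requests, longest sequence] tensor, unpadded models only pay
                    // for the actual tokens.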
                    let entry_tokens = entry.encoding.input_ids.len();
                    let total_tokens = if padded_model {
                        (max(max_length, entry_tokens as u32) * (metadata.len() + 1) as u32)
                            as usize
                    } else {
                        current_tokens + entry_tokens
                    };
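                    // Over budget: put the entry back at the front so it leads the next batch,
                    // and stop filling this one.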
                    if total_tokens > max_batch_tokens {
                        entries.push_front(entry);
                        break;
                    }
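                    // Track which batch positions need a pooled embedding and which are
                    // returned raw (unpooled).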
                    match entry.metadata.pooling {
                        true => pooled_indices.push(entry_index),
                        false => raw_indices.push(entry_index),
                    }
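                    // Append the entry's tokens to the flattened buffers and record its
                    // metadata and end offset.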
                    max_length = max(max_length, entry_tokens as u32);
                    input_ids.extend(entry.encoding.input_ids);
                    token_type_ids.extend(entry.encoding.token_type_ids);
                    position_ids.extend(entry.encoding.position_ids);
                    current_tokens += entry_tokens;
                    metadata.push(entry.metadata);
                    cu_seq_lengths.push(current_tokens as u32);
                    entry_index += 1;
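                    // Stop early once the batch holds the maximum number of requests.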
                    if Some(metadata.len()) == max_batch_requests {
                        break;
                    }
                }
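                // An empty batch is reported as `None` so the caller can tell
                // "nothing queued" apart from a real batch.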
                let batch_size = metadata.len();
                let next_batch = if metadata.is_empty() {
                    None
                } else {
                    Some((
                        metadata,
                        Batch {
                            input_ids,
                            token_type_ids,
                            position_ids,
                            cumulative_seq_lengths: cu_seq_lengths,
                            max_length,
                            pooled_indices,
                            raw_indices,
                        },
                    ))
                };
                let _ = response_sender.send(next_batch);
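                // Record batch size / token histograms and the remaining queue depth.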
                let histogram = metrics::histogram!("te_batch_next_size");
                histogram.record(batch_size as f64);
                let histogram = metrics::histogram!("te_batch_next_tokens");
                histogram.record(current_tokens as f64);
                let gauge = metrics::gauge!("te_queue_size");
                gauge.set(entries.len() as f64);
            }
        }
    }
}