in text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py:
def select(self, input_ids: torch.Tensor, logits: jnp.ndarray) -> jnp.ndarray:
    """Select the next tokens from the candidate logits.

    Args:
        input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation (not used in all generation modes).
        logits (`jnp.ndarray` of shape `(batch_size, vocab_size)`):
            The logits corresponding to the generated tokens.

    Return:
        `jnp.ndarray`: A `jnp.ndarray` containing the selected tokens.
    """
    # The logits processors are written in PyTorch, so the logits are cast to float32 and converted
    # to PyTorch and back to JAX with j2t/t2j (which is somewhat expensive, as it copies the data);
    # otherwise some operations are not supported.
    logits_pt = torch_xla2.tensor.j2t(logits.astype(jnp.float32))
    scores = self.logits_processor(input_ids, logits_pt)
    scores = torch_xla2.tensor.t2j(scores).to_device(logits.device)
    if self.mode == GenerationMode.SAMPLE:
        # Split the key to avoid reusing the same key for multiple samples.
        subkey, self.key = jax.random.split(self.key)
        return self._sample(scores, subkey)
    else:
        return jnp.argmax(scores, axis=-1)
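
The JAX-to-PyTorch round-trip and the key-splitting pattern can be exercised outside the class. The sketch below is illustrative only, not the actual TokenSelector implementation: the `select_next` helper, the `TemperatureLogitsWarper` choice, and the shapes are assumptions; it relies only on `torch_xla2`'s `j2t`/`t2j` converters and standard `transformers` and JAX APIs, and uses `jax.random.categorical` as a stand-in for whatever `_sample` does internally.

# Hypothetical, self-contained sketch of the same pattern: run PyTorch logits
# processors on JAX logits, then sample or take the argmax. `select_next` and
# the temperature warper are illustrative choices, not part of the original class.
import jax
import jax.numpy as jnp
import torch
import torch_xla2
from transformers import LogitsProcessorList, TemperatureLogitsWarper

def select_next(input_ids: torch.Tensor, logits: jnp.ndarray, key, sample: bool):
    processors = LogitsProcessorList([TemperatureLogitsWarper(0.7)])
    # JAX -> PyTorch: cast to float32 first, since some processors do not support bfloat16.
    logits_pt = torch_xla2.tensor.j2t(logits.astype(jnp.float32))
    scores_pt = processors(input_ids, logits_pt)
    # PyTorch -> JAX (copies the data, like the t2j call in select above).
    scores = torch_xla2.tensor.t2j(scores_pt)
    if sample:
        # Split so the consumed subkey is never reused on the next call.
        subkey, key = jax.random.split(key)
        return jax.random.categorical(subkey, scores, axis=-1), key
    return jnp.argmax(scores, axis=-1), key

# Usage: one greedy and one sampled step over a toy vocabulary.
key = jax.random.PRNGKey(0)
ids = torch.zeros((1, 4), dtype=torch.long)
logits = jnp.asarray([[0.1, 2.0, -1.0, 0.5]])
greedy, key = select_next(ids, logits, key, sample=False)
sampled, key = select_next(ids, logits, key, sample=True)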