# Excerpt: `select()` method from
# text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py


    def select(self, input_ids: torch.Tensor, logits: jnp.ndarray) -> jnp.ndarray:
        """Select the next tokens from the candidate logits.

        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation (not used in all generation modes).
            logits (`jnp.ndarray` of shape `(batch_size, vocab_size)`):
                The logits corresponding to the generated tokens.

        Return:
            `jnp.ndarray`: A `jnp.ndarray` containing the selected token ids.
        """
        # The logits processors are written in pytorch, so the logits are cast to float32 and converted
        # to a pytorch tensor and back to jax with j2t/t2j (a bit expensive: both directions copy),
        # otherwise some operations are not supported.
        logits_pt = torch_xla2.tensor.j2t(logits.astype(jnp.float32))
        scores = self.logits_processor(input_ids, logits_pt)
        # Move the processed scores back onto the device the input logits live on.
        scores = torch_xla2.tensor.t2j(scores).to_device(logits.device)

        if self.mode == GenerationMode.SAMPLE:
            # Split the PRNG key so the same key is never reused for successive samples.
            subkey, self.key = jax.random.split(self.key)
            return self._sample(scores, subkey)
        else:
            # Greedy decoding: pick the highest-scoring token per batch row.
            return jnp.argmax(scores, axis=-1)