backends/gaudi/server/text_generation_server/utils/tokens.py [272:441]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            else None
        )

        self.repetition_processor = (
            HeterogeneousRepetitionPenaltyLogitsProcessor(
                repetition_penalty, dtype, device
            )
            if any(x != 1.0 for x in repetition_penalty)
            else None
        )

        self.frequency_processor = (
            HeterogeneousFrequencyPenaltyLogitsProcessor(
                frequency_penalty, dtype, device
            )
            if any(x != 0.0 for x in frequency_penalty)
            else None
        )

        self.grammar_processor = (
            HeterogeneousGrammarLogitProcessor(
                tokenizer, device, grammars, grammar_types
            )
            if any(grammar != "" for grammar in grammars)
            else None
        )

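        # Each warper below also flips `do_sample` to True for any request
        # whose parameter departs from the greedy default.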
        if any(x != 1.0 for x in temperature):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
            ]
            warpers.append(
                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
            )

        if any(x != 0 for x in top_k):
            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

        if any(x < 1.0 for x in top_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

        if any(x < 1.0 for x in typical_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

        self.warpers = warpers

        if any(do_sample):
            self.choice = HeterogeneousSampling(do_sample, seeds, device)
        else:
            self.choice = Greedy()

        self.seeds = seeds
        self.do_sample = do_sample
        self.dtype = dtype
        self.device = device
        self.tokenizer = tokenizer
        self.fsm_grammar_states = fsm_grammar_states
        self.grammars = grammars
        self.grammar_types = grammar_types

    def __call__(
        self,
        input_ids: torch.Tensor,
        scores: torch.Tensor,
        speculate: int,
        speculated_ids: Optional[torch.Tensor] = None,
        speculative_scores: Optional[torch.Tensor] = None,
        verbose=False,
    ):
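        # With speculation, `scores` stacks S = speculated_ids.shape[1] + 1
        # rows of logits per request; without it, each request has one row.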
        if speculated_ids is not None:
            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            scores = scores.view(B, S, -1)
        else:
            B = scores.shape[0]
            S = 1
            scores = scores.view(B, S, -1)

        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)

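        # Handle each candidate position independently: run the logit
        # processors and warpers, then pick one token per request.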
        for j in range(S):
            _scores = scores[:, j]
            if self.watermark_processor is not None:
                _scores = self.watermark_processor(input_ids, _scores)
            if self.repetition_processor is not None:
                _scores = self.repetition_processor(input_ids, _scores)
            if self.frequency_processor is not None:
                _scores = self.frequency_processor(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
            for warper in self.warpers:
                _scores = warper(input_ids, _scores)
            _next_ids = self.choice(_scores)
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids
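        # Flatten back to one row per (request, position) pair and take the
        # log-softmax of the processed scores, so the returned log-probs
        # reflect the applied penalties and warps.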
        next_ids = next_ids.view(B * S)
        allscores = scores.view(B * S, -1)
        alllogprobs = torch.log_softmax(allscores, -1)

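        # Walk the speculated tokens left to right: a guess survives only if
        # it matches the token chosen above, and the first mismatch discards
        # every position after it.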
        if speculated_ids is not None:
            accepted_ids = []
            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            indices = []
            for i in range(B):
                _next_ids = next_ids[i * S : (i + 1) * S]
                _speculated_ids = speculated_ids[i]
                validate_speculative = _next_ids[:-1] == _speculated_ids
                index = i * S
                accepted = 1
                # The first position is always valid: it was computed from
                # fully verified context.
                indices.append(index)
                for valid in validate_speculative.tolist():
                    if valid:
                        index += 1
                        accepted += 1
                        indices.append(index)
                    else:
                        break
                accepted_ids.append(accepted)

            accepted_ids = torch.tensor(
                accepted_ids, device=input_ids.device, dtype=input_ids.dtype
            )
            next_ids = next_ids[indices]
            logprobs = alllogprobs[indices]
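            # Keep only the speculative scores of each request's last
            # accepted position.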
            indices = torch.arange(B, device=input_ids.device) * S
            if speculative_scores is not None:
                speculative_scores = speculative_scores[indices + accepted_ids - 1]
        else:
            accepted_ids = torch.ones_like(next_ids)
            logprobs = alllogprobs

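        # Gather the log-probability of each chosen token.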
        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

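        # Build the next round of speculations: Medusa heads supply scores to
        # pick from greedily, otherwise fall back to n-gram matching.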
        if speculate > 0:
            if speculative_scores is not None:
                # Medusa provided some scores
                speculative_ids = Greedy()(speculative_scores)
            else:
                # n-gram
                speculative_ids = create_n_gram_speculation(
                    input_ids, next_ids, accepted_ids, speculate, verbose
                )
        else:
            speculative_ids = None

        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

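    # Once tokens are committed, the grammar FSM states must be advanced so
    # that constrained decoding stays in sync with the generated text.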
    def advance_grammar(self, next_ids: List[int]):
        if self.grammar_processor is not None:
            other_new_states = self.grammar_processor.advance_batch(
                next_ids, self.fsm_grammar_states
            )
            self.fsm_grammar_states = other_new_states
        return self

    def advance_grammar_single(self, grammar_state_index: int, next_id: int):
        if self.grammar_processor is not None:
            self.fsm_grammar_states[grammar_state_index] = (
                self.grammar_processor.advance_at_index(
                    next_id,
                    self.fsm_grammar_states[grammar_state_index],
                    grammar_state_index,
                )
            )
        return self
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
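
The heart of the speculative path above is the acceptance rule in the loop
over `validate_speculative`. A minimal sketch of that rule using plain Python
lists instead of tensors (the helper name `count_accepted` is illustrative,
not part of the file):

    # The first token is always kept: it was produced from fully verified
    # context. Each speculated token is kept only while it matches the token
    # the model actually chose at that position; the first mismatch stops
    # the scan.
    def count_accepted(next_ids, speculated_ids):
        accepted = 1
        for chosen, guessed in zip(next_ids[:-1], speculated_ids):
            if chosen != guessed:
                break
            accepted += 1
        return accepted

    # Two of three guesses match before the first mismatch, so
    # 1 (mandatory first token) + 2 = 3 positions are accepted.
    assert count_accepted([5, 9, 4, 7], [5, 9, 1]) == 3

The `do_sample` bookkeeping in the constructor follows the same pattern for
every warper: a request switches from greedy decoding to sampling as soon as
any one of its parameters departs from the default. A toy illustration with
made-up values, mirroring the list comprehensions in the constructor:

    do_sample = [False, False, False]
    temperature = [1.0, 0.7, 1.0]
    top_k = [0, 0, 50]
    do_sample = [s or t != 1.0 for t, s in zip(temperature, do_sample)]
    do_sample = [s or k != 0 for k, s in zip(top_k, do_sample)]
    assert do_sample == [False, True, True]  # only request 0 stays greedy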



server/text_generation_server/utils/tokens.py [263:432]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    [170 lines omitted: identical, byte for byte, to the
    backends/gaudi/server/text_generation_server/utils/tokens.py block above]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -