# optimum/graphcore/generation/logits_process.py
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import poptorch
import torch
from transformers.generation.logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
MinLengthLogitsProcessor,
NoRepeatNGramLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)

# Large negative constant used to mask out logits. A finite value rather than
# -inf keeps the arithmetic selects below free of NaNs (0 * -inf is NaN).
VERY_LARGE_NEGATIVE_CONST = -1e18
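
# Every processor below replaces data-dependent control flow with the same
# fixed-shape select idiom, e.g. for an int/bool mask or condition `m`:
#     out = m * scores + (1 - m) * VERY_LARGE_NEGATIVE_CONST
# Dynamic shapes and host-side branching cannot be compiled ahead of time for
# the IPU, so each condition is folded into the arithmetic itself.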


class IPUForcedBOSTokenLogitsProcessor(ForcedBOSTokenLogitsProcessor):
    """IPU variant of `ForcedBOSTokenLogitsProcessor`.

    Precomputes a dense score tensor so that the per-step call is a
    fixed-shape arithmetic select suitable for ahead-of-time IPU compilation.
    """

    @classmethod
    def from_model(cls, inst, vocab_size: int):
        # Retype an existing host-side processor in place, attaching the
        # precomputed tensors it needs.
        self = inst
        # The float scalar multiply promotes the int32 ones to a float tensor.
        self.bos_scores = VERY_LARGE_NEGATIVE_CONST * torch.ones((1, vocab_size), dtype=torch.int32)
        self.bos_scores[:, self.bos_token_id] = 0
        self.__class__ = cls
        return self

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor
    ) -> torch.FloatTensor:
        # Force BOS on the first step: cond == 0 selects bos_scores, otherwise
        # the incoming scores pass through unchanged.
        cond = (absolute_step > 1).int()
        return cond * scores + (1 - cond) * self.bos_scores.to(device=scores.device)


class IPUForcedEOSTokenLogitsProcessor(ForcedEOSTokenLogitsProcessor):
    """IPU variant of `ForcedEOSTokenLogitsProcessor`: forces EOS once
    `absolute_step` reaches `max_length`, using a fixed-shape select."""

    @classmethod
    def from_model(cls, inst, vocab_size: int):
        self = inst
        self.eos_scores = VERY_LARGE_NEGATIVE_CONST * torch.ones((1, vocab_size), dtype=torch.int32)
        self.eos_scores[:, self.eos_token_id] = 0
        self.__class__ = cls
        return self

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor
    ) -> torch.FloatTensor:
        # Before max_length, pass scores through; at and after it, select eos_scores.
        cond = (absolute_step < self.max_length).int()
        return cond * scores + (1 - cond) * self.eos_scores.to(device=scores.device)


class IPUMinLengthLogitsProcessor(MinLengthLogitsProcessor):
    """IPU variant of `MinLengthLogitsProcessor`: masks EOS until
    `absolute_step` reaches `min_length`."""

    @classmethod
    def from_model(cls, inst, vocab_size: int):
        self = inst
        self.__class__ = cls
        self.mask = torch.ones((1, vocab_size), dtype=torch.int32)
        self.mask[:, self.eos_token_id] = 0
        return self

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor
    ) -> torch.FloatTensor:
        mask = self.mask.to(scores.device)
        # Once min_length is reached, cond == 1 re-enables every token.
        cond = absolute_step >= self.min_length
        mask = mask | cond
        return mask * scores + (1 - mask) * VERY_LARGE_NEGATIVE_CONST


class IPUNoRepeatNGramLogitsProcessor(NoRepeatNGramLogitsProcessor):
    """IPU variant of `NoRepeatNGramLogitsProcessor`, rewritten with fixed-shape
    ops (`unfold`, `scatter_add_`, `poptorch.dynamic_slice`) over the padded
    `input_ids` instead of the dynamic Python loops of the host implementation."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor
    ) -> torch.FloatTensor:
        # Zero out padded positions at or beyond cur_len.
        cur_len = absolute_step
        cur_len_mask = torch.arange(0, input_ids.shape[-1], device=input_ids.device).unsqueeze(0) < cur_len
        input_ids = input_ids.view(-1, input_ids.shape[-1])
        input_ids = input_ids * cur_len_mask
        start_idx = torch.maximum(cur_len + 1 - self.ngram_size, torch.tensor(self.ngram_size))
        # Materialise every contiguous window of ngram_size tokens.
        ngrams = input_ids.unfold(-1, self.ngram_size, 1)
        # The last (ngram_size - 1) generated tokens form the prefix to match.
        last_tokens = poptorch.dynamic_slice(input_ids, 1, start_idx, self.ngram_size - 1, 1).unsqueeze(1)
        last_tokens = (start_idx > self.ngram_size) * last_tokens
        mask = torch.all(ngrams[..., : self.ngram_size - 1] == last_tokens, -1)
        # If absolute_step + 1 < ngram_size, no full n-gram exists yet: clear the mask.
        mask = ~(cur_len + 1 < self.ngram_size) * mask
        # Matching windows contribute their final token to the banned set;
        # non-matching windows carry the sentinel index -100 with a zero value.
        idx = torch.where(mask, ngrams[..., -1], -100)
        val = (idx != -100) * torch.ones_like(idx) * VERY_LARGE_NEGATIVE_CONST
        scores.scatter_add_(1, idx, val)
        return scores
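
# The unfold/scatter approach above mirrors the host-side algorithm under
# static shapes: every window of ngram_size consecutive tokens is compared
# against the most recent (ngram_size - 1) generated tokens, and each matching
# window pushes the logit of its final token down by VERY_LARGE_NEGATIVE_CONST
# through a single scatter_add_.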


class IPUSuppressTokensLogitsProcessor(SuppressTokensLogitsProcessor):
    """IPU variant of `SuppressTokensLogitsProcessor`: a static 0/1 mask floors
    the suppressed tokens at every step."""

    @classmethod
    def from_model(cls, inst, vocab_size: int):
        self = inst
        self.__class__ = cls
        self.mask = torch.ones((1, vocab_size), dtype=torch.int32)
        self.mask[:, self.suppress_tokens] = 0
        return self

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor
    ) -> torch.FloatTensor:
        mask = self.mask.to(scores.device)
        return mask * scores + (1 - mask) * VERY_LARGE_NEGATIVE_CONST


class IPUSuppressTokensAtBeginLogitsProcessor(SuppressTokensAtBeginLogitsProcessor):
    """IPU variant of `SuppressTokensAtBeginLogitsProcessor`: the suppression
    mask is only active when `absolute_step` equals `begin_index`."""

    @classmethod
    def from_model(cls, inst, vocab_size: int):
        self = inst
        self.__class__ = cls
        self.mask = torch.ones((1, vocab_size), dtype=torch.int32)
        self.mask[:, self.begin_suppress_tokens] = 0
        return self

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor
    ) -> torch.FloatTensor:
        mask = self.mask.to(scores.device)
        # Away from begin_index, cond == 1 disables the suppression mask.
        cond = absolute_step != self.begin_index
        mask = mask | cond
        return mask * scores + (1 - mask) * VERY_LARGE_NEGATIVE_CONST


class IPUForceTokensLogitsProcessor(ForceTokensLogitsProcessor):
    """IPU variant of `ForceTokensLogitsProcessor`: the step -> token map is
    stored as two parallel tensors so that the forced token can be selected
    with fixed-shape arithmetic."""

    @classmethod
    def from_model(cls, inst, vocab_size: int):
        self = inst
        self.__class__ = cls
        self.force_token_map_keys = torch.tensor(list(self.force_token_map.keys()))
        self.force_token_map_values = torch.tensor(list(self.force_token_map.values()))
        return self

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor
    ) -> torch.FloatTensor:
        # At most one key matches absolute_step; amax extracts its value.
        mask = absolute_step == self.force_token_map_keys.to(scores.device)
        selected_value = torch.amax(mask * self.force_token_map_values.to(scores.device))
        value_mask = torch.arange(scores.shape[1], dtype=torch.long, device=scores.device) == selected_value
        cond = torch.any(mask).int()
        # When a token is forced, floor every logit, then lift the forced
        # token's score back to 0.
        scores = cond * torch.ones_like(scores) * VERY_LARGE_NEGATIVE_CONST + (1 - cond) * scores
        scores -= cond * value_mask.unsqueeze(0) * VERY_LARGE_NEGATIVE_CONST
        return scores
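
# For intuition: with Whisper-style forced decoder ids such as
# {1: <language token>, 2: <task token>} (hypothetical entries), step 1 floors
# every logit at VERY_LARGE_NEGATIVE_CONST and lifts the language token back to
# 0, step 2 does the same for the task token, and every other step leaves the
# scores untouched.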


class IPUWhisperTimeStampLogitsProcessor(WhisperTimeStampLogitsProcessor):
    """IPU variant of `WhisperTimeStampLogitsProcessor`. The host-side rules
    are expressed as precomputed masks and fixed-shape selects:
    <|notimestamps|> is always suppressed; the first timestamp token is forced
    at step `begin_index - 1`; timestamps must appear in pairs, so a lone
    timestamp may only be followed by EOS or another timestamp; the initial
    timestamp is capped at `max_initial_timestamp_index`; and text tokens are
    suppressed whenever the total timestamp log-probability exceeds that of
    the most likely text token.
    """

    @classmethod
    def from_model(cls, inst, vocab_size: int):
        self = inst
        self.__class__ = cls
        self.no_timestamps_mask = torch.ones((1, vocab_size), dtype=torch.int32)
        self.no_timestamps_mask[:, self.no_timestamps_token_id] = 0
        self.after_timestamp_begin_mask = torch.ones((1, vocab_size), dtype=torch.int32)
        self.after_timestamp_begin_mask[:, self.timestamp_begin :] = 0
        self.before_eos_mask = torch.ones((1, vocab_size), dtype=torch.int32)
        self.before_eos_mask[:, : self.eos_token_id] = 0
        self.last_allowed_mask = torch.ones((1, vocab_size), dtype=torch.int32)
        if self.max_initial_timestamp_index is not None:
            self.last_allowed_mask[:, self.timestamp_begin + self.max_initial_timestamp_index + 1 :] = 0
        self.timestamp_begin_scores = VERY_LARGE_NEGATIVE_CONST * torch.ones((1, vocab_size), dtype=torch.int32)
        self.timestamp_begin_scores[:, self.timestamp_begin] = 0
        self.pre_timestamp_begin_mask = torch.ones((1, vocab_size), dtype=torch.int32)
        self.pre_timestamp_begin_mask[:, : self.timestamp_begin] = 0
        return self

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor
    ) -> torch.FloatTensor:
        input_ids = input_ids.view(-1, input_ids.shape[-1])
        # Rule 1: always suppress <|notimestamps|>.
        no_timestamps_mask = self.no_timestamps_mask.to(scores.device)
        scores = no_timestamps_mask * scores + (1 - no_timestamps_mask) * VERY_LARGE_NEGATIVE_CONST
        cur_len = absolute_step
        # Rule 2: force the first timestamp token at step begin_index - 1.
        cond = cur_len == self.begin_index - 1
        scores = ~cond * scores + cond * self.timestamp_begin_scores.to(device=scores.device)
        timestamp_begin_not_forced = ~cond
        # Rule 3: timestamps come in pairs. Determine whether the last and
        # penultimate generated tokens were timestamps.
        last_was_timestamp = torch.index_select(input_ids, 1, cur_len - 1) >= self.timestamp_begin
        last_was_timestamp &= (cur_len - self.begin_index) >= 1
        penultimate_was_timestamp = (
            torch.index_select(input_ids, 1, torch.where(cur_len > 1, cur_len - 2, cur_len - 1))
            >= self.timestamp_begin
        )
        penultimate_was_timestamp |= (cur_len - self.begin_index) < 2
        after_timestamp_begin_mask = self.after_timestamp_begin_mask.to(scores.device)
        before_eos_mask = self.before_eos_mask.to(scores.device)
        # After a completed pair, the next token cannot be a timestamp; after a
        # lone timestamp, it must be EOS or another timestamp.
        after_timestamp_begin_mask = after_timestamp_begin_mask | ~(
            timestamp_begin_not_forced & last_was_timestamp & penultimate_was_timestamp
        )
        before_eos_mask = before_eos_mask | ~(
            timestamp_begin_not_forced & last_was_timestamp & ~penultimate_was_timestamp
        )
        scores = after_timestamp_begin_mask * scores + (1 - after_timestamp_begin_mask) * VERY_LARGE_NEGATIVE_CONST
        scores = before_eos_mask * scores + (1 - before_eos_mask) * VERY_LARGE_NEGATIVE_CONST
        # Rule 4: cap the very first timestamp at max_initial_timestamp_index.
        last_allowed_mask = self.last_allowed_mask.to(scores.device)
        apply_max_initial_timestamp = cur_len == self.begin_index
        last_allowed_mask = last_allowed_mask | ~(timestamp_begin_not_forced & apply_max_initial_timestamp)
        scores = last_allowed_mask * scores + (1 - last_allowed_mask) * VERY_LARGE_NEGATIVE_CONST
        # Rule 5: if the total timestamp log-probability beats every text
        # token, suppress text tokens so a timestamp is sampled.
        log_probs = torch.nn.functional.log_softmax(scores, dim=-1)
        timestamp_logprob = torch.logsumexp(log_probs[:, self.timestamp_begin :], dim=-1, keepdim=True)
        max_text_token_logprob = torch.amax(log_probs[:, : self.timestamp_begin], dim=-1, keepdim=True)
        pre_timestamp_begin_mask = self.pre_timestamp_begin_mask.to(scores.device)
        pre_timestamp_begin_mask = pre_timestamp_begin_mask | ~(
            timestamp_begin_not_forced & (timestamp_logprob > max_text_token_logprob)
        )
        scores = pre_timestamp_begin_mask * scores + (1 - pre_timestamp_begin_mask) * VERY_LARGE_NEGATIVE_CONST
        return scores


IPULogitsProcessors = {
ForcedBOSTokenLogitsProcessor: IPUForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor: IPUForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor: IPUForceTokensLogitsProcessor,
MinLengthLogitsProcessor: IPUMinLengthLogitsProcessor,
NoRepeatNGramLogitsProcessor: IPUNoRepeatNGramLogitsProcessor,
SuppressTokensLogitsProcessor: IPUSuppressTokensLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor: IPUSuppressTokensAtBeginLogitsProcessor,
WhisperTimeStampLogitsProcessor: IPUWhisperTimeStampLogitsProcessor,
}
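
# A minimal sketch of how this mapping might be used to swap stock Hugging Face
# processors for their IPU counterparts (`logits_processor` and `model` are
# hypothetical names; the actual wiring lives elsewhere in
# `optimum.graphcore.generation`):
#
#     for i, processor in enumerate(logits_processor):
#         ipu_cls = IPULogitsProcessors.get(type(processor))
#         if ipu_cls is None:
#             continue
#         if hasattr(ipu_cls, "from_model"):
#             logits_processor[i] = ipu_cls.from_model(processor, model.config.vocab_size)
#         else:
#             processor.__class__ = ipu_cls  # processors that only override __call__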