in parler_tts/logits_processors.py [0:0]
def __init__(self, eos_token_id, num_codebooks: int, batch_size: int, device: str = "cpu"):
    """Store the EOS id(s) and precompute flat codebook index tensors for the batch.

    Args:
        eos_token_id: EOS id(s) given as an ``int``, a sequence of ints, or a
            ``torch.Tensor``. Non-tensor values are converted to a tensor on
            ``device``; a bare ``int`` is wrapped in a list first, so the result
            is 1-D. A tensor argument is stored as-is and is NOT moved to ``device``.
        num_codebooks: Number of codebooks per batch item.
        batch_size: Number of batch items.
        device: Device on which the converted EOS tensor and the index tensors
            are created.

    Raises:
        ValueError: If ``eos_token_id`` has a floating-point dtype or contains a
            negative entry. Note that ``0`` passes this check even though the
            error message says "positive".
    """
    if not isinstance(eos_token_id, torch.Tensor):
        as_list = [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
        eos_token_id = torch.tensor(as_list, device=device)
    self.eos_token_id = eos_token_id
    self.batch_size = batch_size

    # Float dtype is checked first so the comparison only runs on non-float tensors.
    is_invalid = torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()
    if is_invalid:
        raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

    self.num_codebooks = num_codebooks
    self.device = device

    # NOTE(review): the arithmetic below implies a batch-major flat layout where row
    # b * num_codebooks + k is codebook k of batch item b — confirm against the caller.
    total_rows = batch_size * num_codebooks
    self.codebook_idx = torch.arange(total_rows, device=device)
    row_starts = torch.arange(batch_size, device=device) * num_codebooks
    # First flat row of each batch item: 0, K, 2K, ...
    self.first_codebooks_unfinished = row_starts
    # Last flat row of each batch item: K-1, 2K-1, ... (a new tensor, not a view of row_starts)
    self.max_codebooks = row_starts + (num_codebooks - 1)