arctic_inference/vllm/spec_dec/arctic_speculator.py [295:321]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                                          head_index)
            previous_hidden_states = states
            states = states.flatten(0, 1)
            head_weight = (self.qhead[head_index] if self.qhead is not None
                           and batch_size <= 32 else self.head[head_index])
            logits = self.logits_processor(head_weight, states)

            if self.tp_size == 1:
                last_tokens = torch.argmax(logits,
                                           dim=-1).reshape(batch_size, -1)
            else:
                vals, indices = torch.topk(logits, 1, dim=-1)
                indices = indices + self.tp_rank * logits.shape[-1]

                packed_data = torch.cat(
                    [vals.to(torch.float64).view(torch.int64), indices], dim=0)
                packed_data = self.TP_GROUP.all_gather(packed_data)
                vals, indices = packed_data.split(batch_size, dim=0)
                vals = vals.view(torch.float64)

                argidx = torch.argmax(vals, -1).reshape(batch_size, -1)
                last_tokens = torch.gather(indices, -1, argidx)

            if next_tokens_tensors[head_index] == None:
                next_tokens_tensors[head_index] = last_tokens
            else:
                next_tokens_tensors[head_index].copy_(last_tokens)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



arctic_inference/vllm/spec_dec/arctic_speculator.py [723:749]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                                              head_index)
            previous_hidden_states = states
            states = states.flatten(0, 1)
            head_weight = (self.qhead[head_index] if self.qhead is not None
                           and batch_size <= 32 else self.head[head_index])
            logits = self.logits_processor(head_weight, states)

            if self.tp_size == 1:
                last_tokens = torch.argmax(logits,
                                           dim=-1).reshape(batch_size, -1)
            else:
                vals, indices = torch.topk(logits, 1, dim=-1)
                indices = indices + self.tp_rank * logits.shape[-1]

                packed_data = torch.cat(
                    [vals.to(torch.float64).view(torch.int64), indices], dim=0)
                packed_data = self.TP_GROUP.all_gather(packed_data)
                vals, indices = packed_data.split(batch_size, dim=0)
                vals = vals.view(torch.float64)

                argidx = torch.argmax(vals, -1).reshape(batch_size, -1)
                last_tokens = torch.gather(indices, -1, argidx)

            if next_tokens_tensors[head_index] == None:
                next_tokens_tensors[head_index] = last_tokens
            else:
                next_tokens_tensors[head_index].copy_(last_tokens)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



