# maga_transformer/models/sgpt_bloom.py
import torch
from maga_transformer.models.bloom import Bloom
from maga_transformer.models.base_model import GenerateOutput
from maga_transformer.distribute.worker_info import g_parallel_info
from maga_transformer.model_factory_register import register_model


class SGPTBloom(Bloom):
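    """BLOOM variant for SGPT-style sentence embeddings.

    Instead of sampling tokens, generation yields the final-layer hidden states
    of the prompt, which callers can pool into sentence embeddings.
    """
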
    @torch.no_grad()
    def generate_hidden_states_stream(self, input_token_ids: torch.IntTensor):
        assert self.weight is not None, 'Please call load() first to initialize weights.'
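        # Inputs are assumed to be right-padded with EOS, so the true length of each
        # sequence is the number of non-EOS tokens.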
        input_token_ids_np = input_token_ids.cpu().numpy()
        batch_size = len(input_token_ids_np)
        eos_token_id = self.config.special_tokens.eos_token_id
        input_lengths = torch.IntTensor([len(v[v != eos_token_id]) for v in input_token_ids_np])
        input_token_ids = input_token_ids.type(torch.int32).to(self.device)
        input_lengths = input_lengths.type(torch.int32).to(self.device)
        max_input_length = input_token_ids.shape[-1]
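        # No tokens are sampled: a single-step generation budget is enough to run
        # one context (prefill) pass over the prompt.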
        gen_length = 1
        beam_width = 1
        max_seq_length = max_input_length + gen_length
        memory_length = max_seq_length
        device = self.device
        # tril() doesn't support the bf16 dtype, so build the causal mask as bool first
        # and cast it to the model dtype afterwards.
        attention_mask = torch.ones(
            (max_input_length, max_input_length), dtype=torch.bool, device=device)\
            .tril().unsqueeze(0)
        attention_mask = attention_mask.tile(input_token_ids.shape[0], 1, 1).to(self.dtype)
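        # Zero out the rows past each sequence's true length so padded positions
        # do not attend to anything.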
        for b, input_length in enumerate(input_lengths):
            attention_mask[b, input_length:, ...] = 0
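        # Only the first pipeline-parallel stage embeds the tokens; later stages
        # receive hidden states from the previous stage and only need a correctly
        # shaped placeholder tensor.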
        if g_parallel_info.is_pp_first:
            # Prepare the decoder input embeddings.
            input_embeds = self.word_embedding(input_token_ids)
            if self.position_encoding is not None:
                position_ids = torch.arange(0, max_input_length, dtype=torch.int, device=device)
                position_ids = position_ids.unsqueeze(0).view(-1, max_input_length)
                input_embeds += self.position_encoding(position_ids)
            if self.pre_decoder_layernorm is not None:
                input_embeds = self.pre_decoder_layernorm(input_embeds)
        else:
            # Dummy input_embeds with the expected shape and dtype.
            input_embeds = torch.empty(
                size=(batch_size * beam_width, max_input_length, self.context_decoder.hidden_size),
                dtype=self.context_decoder.dtype,
                device=device)
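        # Run the context (prefill) decoder over the whole prompt; only the final
        # hidden states are kept, the remaining outputs are discarded.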
        hidden_states, _, _, _ = self.context_decoder.forward(
            input_embeds=input_embeds,
            attention_mask=attention_mask,
            input_lengths=input_lengths,
            memory_length=memory_length,
            linear_bias_slopes=self.linear_bias_slopes)
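        # Apply the final layer norm before yielding the hidden states.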
        hidden_states = self.post_decoder_layernorm(hidden_states)  # type: ignore
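        # Yield the hidden states in place of sampled tokens: the returned ids simply
        # echo the input, every sequence is marked finished, and aux info is empty.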
        yield GenerateOutput(hidden_states,
                             input_token_ids,
                             torch.ones_like(input_lengths).bool(),
                             [{}] * input_lengths.shape[0])
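
    # generate_stream is overridden to bypass sampling entirely: callers get the
    # hidden-state stream regardless of the generate_config they pass in.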
    @torch.no_grad()
    def generate_stream(self,  # type: ignore
                        input_token_ids, input_lengths, generate_config):
        return self.generate_hidden_states_stream(input_token_ids=input_token_ids)

register_model('sgpt_bloom', SGPTBloom)
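
# --- Usage sketch (illustrative, not part of the original module) ---------------
# A minimal sketch of how the hidden-state stream might be consumed to build
# sentence embeddings. `model` (an already-loaded SGPTBloom instance), `token_ids`
# (a right-padded [batch, seq_len] int32 tensor) and the `hidden_states` attribute
# on GenerateOutput are assumptions for illustration only.
#
#   for out in model.generate_hidden_states_stream(token_ids):
#       states = out.hidden_states  # [batch, seq_len, hidden] (assumed field name)
#       # Mean-pool the valid positions of each sequence to obtain one embedding
#       # per input sentence, as is typical for SGPT-style encoders.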