maga_transformer/models/sgpt_bloom_vector.py:

import torch
import numpy as np
from typing import Any, Dict, List

from maga_transformer.models.sgpt_bloom import SGPTBloom
from maga_transformer.models.base_model import GenerateOutput
from maga_transformer.model_factory_register import register_model


class SGPTBloomVector(SGPTBloom):
    @torch.no_grad()
    def generate_weighted_hidden_states_stream(self, input_token_ids: torch.IntTensor):
        eos_token_id = self.config.special_tokens.eos_token_id
        batch_size = input_token_ids.shape[0]
        # Mask out padding positions (inputs are padded with eos_token_id).
        input_mask = torch.where(input_token_ids != eos_token_id, 1, 0)
        gen_output = list(self.generate_hidden_states_stream(input_token_ids))[0]
        hidden_states = gen_output.hidden_states
        # Position weights 1..seq_len: later tokens contribute more
        # (SGPT-style position-weighted mean pooling).
        weights = (
            torch.arange(start=1, end=hidden_states.shape[1] + 1)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(hidden_states.size())
            .float().to(hidden_states.device)
        )
        # input_mask.shape          = [batch, seq_len]
        # input_mask_expanded.shape = [batch, seq_len, hidden_dim]
        input_mask_expanded = (
            input_mask
            .unsqueeze(-1)
            .expand(hidden_states.size())
            .float()
        ).to(hidden_states.device)
        # Perform weighted mean pooling across seq_len:
        # [batch, seq_len, hidden_dim] -> [batch, hidden_dim]
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded * weights, dim=1)
        sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
        embeddings = sum_embeddings / sum_mask
        embeddings = embeddings.cpu()
        # L2-normalize so that dot products between embeddings equal cosine similarity.
        norm = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        yield GenerateOutput(norm,
                             input_token_ids.unsqueeze(1),  # add beam dim
                             torch.ones(batch_size),
                             [{"decimals": 6}] * batch_size)

    @staticmethod
    def process_encode_plugin(prompt: str, generate_config: Dict[str, Any],
                              tokenizer: Any, max_seq_len: int, **kwargs: Any) -> List[int]:
        custom_gen_cfg = generate_config["custom_prop"]
        is_query = custom_gen_cfg.get("is_query", False)
        case_sensitive = custom_gen_cfg.get("case_sensitive", False)
        prompt = prompt if case_sensitive else prompt.lower()
        tokenizer = tokenizer.tokenizer  # PreTrainedTokenizerFast
        # Reserve two positions for the SGPT bracket tokens added below.
        batch_tokens = tokenizer(prompt, padding=False, truncation=True,
                                 max_length=max_seq_len - 2)
        input_ids = batch_tokens["input_ids"]
        if is_query:
            # SGPT convention: queries are wrapped in "[...]", documents in "{...}".
            SPECB_QUE_BOS = tokenizer.encode("[", add_special_tokens=False)[0]
            SPECB_QUE_EOS = tokenizer.encode("]", add_special_tokens=False)[0]
            input_ids.insert(0, SPECB_QUE_BOS)
            input_ids.append(SPECB_QUE_EOS)
        else:
            SPECB_DOC_BOS = tokenizer.encode("{", add_special_tokens=False)[0]
            SPECB_DOC_EOS = tokenizer.encode("}", add_special_tokens=False)[0]
            input_ids.insert(0, SPECB_DOC_BOS)
            input_ids.append(SPECB_DOC_EOS)
        return input_ids

    @torch.no_grad()
    def generate_stream(self, input_token_ids, input_lengths, generate_config):
        return self.generate_weighted_hidden_states_stream(input_token_ids=input_token_ids)


register_model('sgpt_bloom_vector', SGPTBloomVector)
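

# ---------------------------------------------------------------------------
# Minimal standalone sketch (not part of the original file): demonstrates the
# position-weighted mean pooling used above on a toy tensor, so the math can
# be checked without loading a model. The shapes and the choice of
# eos_token_id = 0 as the padding id are assumptions for this example only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seq_len, hidden_dim = 2, 4, 3
    eos_token_id = 0  # assumed padding id for this toy example
    token_ids = torch.tensor([[5, 6, 7, eos_token_id],
                              [8, 9, eos_token_id, eos_token_id]])
    hidden_states = torch.randn(batch, seq_len, hidden_dim)

    mask = (token_ids != eos_token_id).float().unsqueeze(-1)       # [batch, seq_len, 1]
    weights = torch.arange(1, seq_len + 1).float().view(1, -1, 1)  # positions 1..seq_len
    pooled = (hidden_states * mask * weights).sum(1) / (mask * weights).sum(1)
    normed = pooled / pooled.norm(dim=1, keepdim=True)             # unit-length embeddings
    print(normed.shape)  # torch.Size([2, 3])

    # Sketch of the bracket wrapping done by process_encode_plugin, using the
    # standard Hugging Face tokenizer API; the model name is an assumption,
    # and any BLOOM tokenizer behaves the same way (shown commented out to
    # avoid a download at import time):
    # from transformers import AutoTokenizer
    # tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    # ids = tok("hello world", truncation=True, max_length=510)["input_ids"]
    # que_bos = tok.encode("[", add_special_tokens=False)[0]
    # que_eos = tok.encode("]", add_special_tokens=False)[0]
    # query_ids = [que_bos] + ids + [que_eos]  # queries: "[...]", docs: "{...}"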