# optimum/habana/transformers/models/glm4v/modeling_chatglm.py
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company
###############################################################################
"""PyTorch GLM-4V model."""
import math
import os
from typing import List, Optional, Tuple, Union
import habana_frameworks.torch.core as htcore
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn.utils import skip_init
from transformers.generation import GenerationMixin
from transformers.generation.logits_process import LogitsProcessor
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from optimum.habana.transformers.modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
from .configuration_chatglm import GLM4VConfig
from .visual import EVA2CLIPModel
"""
Adapted from the following source:
https://huggingface.co/THUDM/glm-4v-9b/blob/main/modeling_chatglm.py
"""
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Cannot import Fused SDPA from Habana Torch")
FusedSDPA = None
try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV3 as FusedRoPE
except ImportError:
print("Cannot import Fused Rope from Habana Torch")
FusedRoPE = None
try:
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
except ImportError:
print("Cannot import Fused RMSNorm from Habana Torch")
FusedRMSNorm = None
logger = logging.get_logger(__name__)
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1
_CHECKPOINT_FOR_DOC = "THUDM/GLM4V"
_CONFIG_FOR_DOC = "GLM4VConfig"
# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA
    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return torch.matmul(x, y)
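# Note: Matmul and ModuleFusedSDPA are thin nn.Module wrappers around torch.matmul and the HPU
# FusedSDPA kernel. Keeping these ops as submodules (rather than bare function calls) is
# presumably what lets Habana quantization/measurement hooks patch them per layer.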
def default_init(cls, *args, **kwargs):
return cls(*args, **kwargs)
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 198] = 5e4
return scores
class PrefixEncoder(torch.nn.Module):
"""
The torch.nn model to encode the prefix
Input shape: (batch-size, prefix-length)
Output shape: (batch-size, prefix-length, 2*layers*hidden)
"""
def __init__(self, config: GLM4VConfig):
super().__init__()
self.prefix_projection = config.prefix_projection
if self.prefix_projection:
# Use a two-layer MLP to encode the prefix
kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
self.trans = torch.nn.Sequential(
torch.nn.Linear(kv_size, config.hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.hidden_size, kv_size),
)
else:
self.embedding = torch.nn.Embedding(
config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2
)
def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
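# Illustrative shape example for PrefixEncoder (values are assumptions, not necessarily the GLM-4V
# defaults): with pre_seq_len=16, num_layers=40, kv_channels=128, multi_query_group_num=2, a prefix
# of shape (batch, 16) is mapped to past_key_values of shape (batch, 16, 40 * 128 * 2 * 2) = (batch, 16, 20480).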
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = tensor.size()[last_dim] // num_partitions
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
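# Usage sketch (illustrative): splitting a fused QKV projection of shape [sq, b, np, 3 * hn] into
# three [sq, b, np, hn] tensors, as done in SelfAttention.forward below:
#   query_layer, key_layer, value_layer = split_tensor_along_last_dim(mixed_x_layer, 3)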
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
self.register_buffer("inv_freq", inv_freq)
self.dim = dim
self.original_impl = original_impl
self.rope_ratio = rope_ratio
self.seq_len_record = -1
self.cache = None
def impl(self, seq_length: int, dim: int, device: torch.device, dtype: torch.dtype):
if self.seq_len_record != seq_length:
self.seq_len_record = seq_length
base = 10000 * self.rope_ratio
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
seq = torch.arange(seq_length, device=inv_freq.device, dtype=torch.float32)
freqs = torch.outer(seq, inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
self.cache = torch.cat((freqs, freqs), dim=-1)
return self.cache
def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
if self.seq_len_record != seq_len:
self.seq_len_record = seq_len
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
base = base * self.rope_ratio
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta).float()
self.cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
# this is to mimic the behaviour of complex32, else we will get different results
if dtype in (torch.float16, torch.bfloat16, torch.int8):
self.cache = self.cache.bfloat16() if dtype == torch.bfloat16 else self.cache.half()
return self.cache
def forward(self, max_seq_len, offset=0):
if self.original_impl:
return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
else:
return self.impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
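# Note on the two cache layouts above: `impl` returns a [seq_len, dim] cache with the frequency
# table duplicated along the last dim, while `forward_impl` (used when original_impl=True) returns
# a [seq_len, dim // 2, 2] cache of stacked (cos, sin) pairs, which is the layout consumed by
# apply_rotary_pos_emb below.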
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
data_dtype = x.dtype
compute_dtype = rope_cache.dtype
if x.device.type == "hpu" and FusedRoPE is not None:
x_out = FusedRoPE.apply(x.to(compute_dtype), rope_cache)
else:
x = x.to(compute_dtype)
# x: [sq, b, np, hn]
sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3)
rot_dim = rope_cache.shape[-2] * 2
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
# truncate to support variable sizes
rope_cache = rope_cache[:sq]
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
],
-1,
)
x_out2 = x_out2.flatten(3)
x_out = torch.cat((x_out2, x_pass), dim=-1)
return x_out.to(data_dtype)
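# Expected layouts for apply_rotary_pos_emb: x is [sq, b, np, hn] and rope_cache holds per-position
# (cos, sin) pairs. On the fallback path only the first rot_dim channels of each head are rotated;
# the remaining channels (x_pass) are passed through unchanged.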
class RMSNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
self.eps = eps
def forward(self, hidden_states: torch.Tensor):
if hidden_states.device.type == "hpu" and FusedRMSNorm is not None:
# mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype
if hidden_states.dtype != self.weight.dtype:
orig_dtype = hidden_states.dtype
hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.eps)
return hidden_states.to(orig_dtype)
else:
hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.eps)
return hidden_states
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return self.weight * hidden_states.to(input_dtype)
def gaudi_chatglm_repeat_kv(
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
):
"""
Refer https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/models/llama/modeling_llama.py#L109
Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
    - Add a num_key_value_heads == 1 check: in that case the kv states can be broadcast during the matmuls, so there is no need to expand and reshape them.
- Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion.
The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
"""
batch, num_query_heads, q_len, head_dim = query_layer.shape
batch, num_key_value_heads, kv_len, head_dim = key_layer.shape
n_rep = num_query_heads // num_key_value_heads
if n_rep == 1 or num_key_value_heads == 1:
return query_layer, key_layer, value_layer, attention_mask
new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
key_layer = key_layer.reshape(new_kv_shape)
value_layer = value_layer.reshape(new_kv_shape)
new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
query_layer = query_layer.reshape(new_q_shape)
if attention_mask is not None:
# Add groups dim and set to 1
attention_mask = attention_mask.unsqueeze(1)
return query_layer, key_layer, value_layer, attention_mask
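# Shape sketch (illustrative numbers, assuming 32 query heads and 2 KV groups): query
# [b, 32, q_len, hn] and key/value [b, 2, kv_len, hn] become query [b, 2, 16, q_len, hn] and
# key/value [b, 2, 1, kv_len, hn], so the size-1 group dim broadcasts during the attention matmuls.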
class KVCache(torch.nn.Module):
def __init__(self):
super().__init__()
self.cache = None
self.inp_seq_len = -1
def allocate(self, inp_seq_len, dtype, device, shape):
if self.cache is None or self.cache.shape != shape:
self.inp_seq_len = inp_seq_len
self.cache = torch.zeros(shape, dtype=dtype, device=device)
else:
assert self.inp_seq_len == inp_seq_len, (
f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
)
self.cache.fill_(0)
def update(self, prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
if prev.shape == cur.shape:
prev.copy_(cur)
return orig_cur
if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
# Initialize
prev[:, :, :inp_seq_len, :].copy_(cur)
return orig_cur
assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
if idx is not None:
prev.index_copy_(dim, idx - 1, cur)
return prev
else:
return torch.cat((prev, cur), dim=dim)
def get_shape(self):
if self.cache is None:
return None
return self.cache.shape
def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
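# KVCache.update handles three cases: (1) prev and cur have the same shape -> copy cur into prev
# in place, (2) cur is a multi-token prompt no longer than the pre-allocated cache -> copy it into
# the first inp_seq_len slots, (3) cur is a single decode token -> scatter it at position
# token_idx - 1 via index_copy_, or concatenate along dim when no index is given.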
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
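# Illustrative example for _expand_mask: a padding mask [[1, 1, 0]] (1 = attend) becomes a
# [1, 1, tgt_len, 3] additive mask containing 0.0 where attention is allowed and the dtype minimum
# where it is masked out.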
class CoreAttention(torch.nn.Module):
def __init__(self, config: GLM4VConfig, layer_number):
super().__init__()
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values.
self.hidden_size_per_partition = projection_size
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
self.num_attention_heads_per_partition = config.num_attention_heads
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.coeff = coeff
self.dropout_rate = config.attention_dropout
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA is not None else None
self.q_block_size = 8192
    def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, softmax_mode):
        """
        Gaudi version of Flash Attention V1 to support long sequences at the prompt phase.
        Causal masking is not supported in this optimization.
        """
q_len = query_layer.size(-2)
q_tiles = (
(q_len // self.q_block_size) if (q_len % self.q_block_size == 0) else math.ceil(q_len / self.q_block_size)
)
q_padding = q_tiles * self.q_block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
if attention_mask is not None:
attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", torch.finfo(key_layer.dtype).min)
row_o_list = []
for i in range(q_tiles):
s, e = i * self.q_block_size, (i + 1) * self.q_block_size
row_q = query_layer[:, :, s:e, :]
row_mask = attention_mask[:, :, s:e, :]
            attn_output_partial = self.fused_scaled_dot_product_attention(
                row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, softmax_mode
            )
row_o_list.append(attn_output_partial)
attn_output = torch.cat(row_o_list, dim=-2)
if q_padding != 0:
attn_output = attn_output[:, :, :-q_padding, :]
return attn_output
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
cache_position: Optional[torch.LongTensor] = None,
attn_softmax_bf16: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
**kwargs,
):
bsz, _, q_len, _ = query_layer.shape
if use_flash_attention and FusedSDPA is not None:
import habana_frameworks.torch.hpu as ht
softmax_mode = "fast" if flash_attention_fast_softmax else "None"
dropout_rate = 0.0
if self.training:
dropout_rate = self.dropout_rate
if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask, dropout_rate, False, None, softmax_mode
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_layer, key_layer, value_layer, None, dropout_rate, True, None, softmax_mode
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                        # Workaround for long sequences: better performance than recomputation
if (q_len > 16384 or (q_len >= 6144 and bsz >= 2)) and self.training:
attn_output = self.gaudi_flash_attn_v1(
query_layer, key_layer, value_layer, attention_mask, dropout_rate, softmax_mode
)
else:
attn_output = self.fused_scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attention_mask,
dropout_rate,
False,
None,
softmax_mode,
)
else:
query_layer, key_layer, value_layer, attention_mask = gaudi_chatglm_repeat_kv(
query_layer, key_layer, value_layer, attention_mask
)
attn_weights = self.matmul_qk(query_layer, key_layer.transpose(-2, -1)) / self.norm_factor
if self.coeff is not None:
attn_weights = attn_weights * self.coeff
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_layer.shape[-2]]
attn_weights = attn_weights + causal_mask
if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_layer.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_layer.dtype
)
if self.training:
attn_weights = self.attention_dropout(attn_weights)
attn_output = self.matmul_av(attn_weights, value_layer)
attn_output = attn_output.reshape(bsz, -1, q_len, self.hidden_size_per_attention_head)
# =================
# Output. [sq, b, h]
# =================
attn_output = attn_output.permute(2, 0, 1, 3).contiguous()
context_layer = attn_output.reshape(q_len, bsz, self.hidden_size_per_partition)
return context_layer
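# The attention output is returned in [sq, b, h] layout (sequence-first), matching the
# hidden_states layout used throughout the GLM transformer blocks below.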
CORE_ATTENTION_CLASSES = {"eager": CoreAttention, "sdpa": CoreAttention, "flash_attention_2": CoreAttention}
class SelfAttention(torch.nn.Module):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, config: GLM4VConfig, layer_number, device=None):
super().__init__()
self.config = config
self.layer_number = max(1, layer_number)
self.projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values.
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
self.num_attention_heads_per_partition = config.num_attention_heads
self.multi_query_attention = config.multi_query_attention
self.qkv_hidden_size = 3 * self.projection_size
self.original_rope = config.original_rope
if self.multi_query_attention:
self.num_multi_query_groups_per_partition = config.multi_query_group_num
self.qkv_hidden_size = (
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
)
self.query_key_value = torch.nn.Linear(
config.hidden_size,
self.qkv_hidden_size,
bias=config.add_bias_linear or config.add_qkv_bias,
device=device,
**_config_to_kwargs(config),
)
self.core_attention = CoreAttention(config, self.layer_number)
# Output.
self.dense = torch.nn.Linear(
self.projection_size,
config.hidden_size,
bias=config.add_bias_linear,
device=device,
**_config_to_kwargs(config),
)
self.k_cache = KVCache()
self.v_cache = KVCache()
self.inp_seq_len = -1
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
cache_shape = (
batch_size,
self.num_multi_query_groups_per_partition,
max_seq_len,
self.hidden_size_per_attention_head,
)
device = self.query_key_value.weight.device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
def reorder(self, tensor, beam_idx, dim_a, dim_b):
updated = tensor.index_select(0, beam_idx)
tensor.copy_(updated)
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
if self.k_cache.cache is None:
return (None, None)
head_dim = self.k_cache.cache.size(-1)
seq_length = self.k_cache.cache.size(-2)
self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim)
self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
return (self.k_cache.cache.shape, self.v_cache.cache.shape)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
prefix_encoder: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
):
# hidden_states: [sq, b, h]
q_len, bsz, hiddenSize = hidden_states.size()
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)
if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
],
dim=-1,
)
query_layer = query_layer.view(
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
)
key_layer = key_layer.view(
key_layer.size()[:-1]
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
)
value_layer = value_layer.view(
value_layer.size()[:-1]
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
)
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
if prefix_encoder is not None:
prefix_encoder_key, prefix_encoder_value = prefix_encoder
if mixed_x_layer.dtype == torch.float8_e4m3fn:
from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2
prefix_encoder_key = cast_to_fp8_v2(prefix_encoder_key, None, False, False, mixed_x_layer.dtype)[0]
prefix_encoder_value = cast_to_fp8_v2(prefix_encoder_value, None, False, False, mixed_x_layer.dtype)[0]
else:
prefix_encoder_key = prefix_encoder_key.to(mixed_x_layer.dtype)
prefix_encoder_value = prefix_encoder_value.to(mixed_x_layer.dtype)
key_layer = torch.cat((prefix_encoder_key, key_layer), dim=0)
value_layer = torch.cat((prefix_encoder_value, value_layer), dim=0)
query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
value_layer = value_layer.permute(1, 2, 0, 3).contiguous()
if use_cache:
# reuse k, v, self_attention
if reuse_cache:
key_layer = self.k_cache(key_layer, 2, token_idx)
value_layer = self.v_cache(value_layer, 2, token_idx)
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
if past_key_value is None:
past_key = torch.zeros(
key_layer.shape, dtype=self.query_key_value.weight.dtype, device=key_layer.device
)
past_value = torch.zeros(
key_layer.shape, dtype=self.query_key_value.weight.dtype, device=key_layer.device
)
past_key_value = [past_key, past_value]
key_layer = self.k_cache.update(past_key_value[0], key_layer, 2, token_idx, self.inp_seq_len)
value_layer = self.v_cache.update(past_key_value[1], value_layer, 2, token_idx, self.inp_seq_len)
if token_idx is None:
past_key_value = (key_layer, value_layer)
if cache_idx is not None and q_len == 1:
key_layer = key_layer[:, :, :cache_idx, :]
value_layer = value_layer[:, :, :cache_idx, :]
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, :cache_idx]
else:
past_key_value = None
# ==================================
# core attention computation
# ==================================
context_layer = self.core_attention(
query_layer,
key_layer,
value_layer,
attention_mask,
cache_position,
attn_softmax_bf16,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
**kwargs,
)
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
# No output_attention
attn_weights = None
return output, attn_weights, past_key_value
def _config_to_kwargs(args):
common_kwargs = {
"dtype": args.torch_dtype,
}
return common_kwargs
class MLP(torch.nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self, config: GLM4VConfig, device=None):
super().__init__()
self.add_bias = config.add_bias_linear
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = torch.nn.Linear(
config.hidden_size,
config.ffn_hidden_size * 2,
bias=self.add_bias,
device=device,
**_config_to_kwargs(config),
)
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.activation_func = swiglu
# Project back to h.
self.dense_4h_to_h = torch.nn.Linear(
config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output = self.dense_4h_to_h(intermediate_parallel)
return output
class GLMBlock(torch.nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(self, config: GLM4VConfig, layer_number, device=None):
super().__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
self.fp32_residual_connection = config.fp32_residual_connection
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = LayerNormFunc(
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
)
# Self attention.
self.self_attention = SelfAttention(config, layer_number, device=device)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
self.post_attention_layernorm = LayerNormFunc(
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
)
# MLP
self.mlp = MLP(config, device=device)
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.self_attention.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.self_attention.reorder_kv_cache(beam_idx)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
prefix_encoder: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, self_attn_weights, present_key_value = self.self_attention(
layernorm_output,
attention_mask,
prefix_encoder,
rotary_pos_emb,
past_key_value,
output_attentions,
use_cache,
cache_position,
token_idx,
attn_softmax_bf16,
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
**kwargs,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
layernorm_input = residual + layernorm_input
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
output = residual + output
outputs = (output,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class GLMTransformer(torch.nn.Module):
"""Transformer class."""
def __init__(self, config: GLM4VConfig, device=None):
super().__init__()
self.fp32_residual_connection = config.fp32_residual_connection
self.post_layer_norm = config.post_layer_norm
# Number of layers.
self.num_layers = config.num_layers
# Transformer layers.
def build_layer(layer_number):
return GLMBlock(config, layer_number, device=device)
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
if self.post_layer_norm:
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = LayerNormFunc(
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
)
self.gradient_checkpointing = False
def _get_layer(self, layer_number):
return self.layers[layer_number]
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
for layer in self.layers:
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
prefix_encoders: Optional[List[torch.FloatTensor]] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
if lazy_mode:
htcore.mark_step()
for index in range(self.num_layers):
if (
lazy_mode
and not self.training
and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
):
htcore.mark_step()
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
past_key_value = past_key_values[index] if past_key_values is not None else None
prefix_encoder = prefix_encoders[index] if prefix_encoders is not None else None
layer = self._get_layer(index)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(
*inputs,
None,
output_attentions,
use_cache,
cache_position,
None,
attn_softmax_bf16,
False,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
attention_mask,
prefix_encoder,
rotary_pos_emb,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
prefix_encoder=prefix_encoder,
rotary_pos_emb=rotary_pos_emb,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
next_cache = next_decoder_cache if use_cache else None
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# Final layer norm.
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states, next_cache, all_hidden_states, all_self_attns
class GLM4VPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
is_parallelizable = False
supports_gradient_checkpointing = True
config_class = GLM4VConfig
base_model_prefix = "transformer"
_no_split_modules = ["GLMBlock"]
def _init_weights(self, module: torch.nn.Module):
"""Initialize the weights."""
return
def get_masks(self, input_embeds, past_key_values, padding_mask=None):
batch_size, seq_length, embed_size = input_embeds.shape
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
full_attention_mask.tril_()
past_length = 0
if past_key_values:
past_length = past_key_values[0][0].shape[2]
if past_length:
full_attention_mask = torch.cat(
(torch.ones(batch_size, seq_length, past_length, device=input_embeds.device), full_attention_mask),
dim=-1,
)
if padding_mask is not None:
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
if not past_length and padding_mask is not None:
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
full_attention_mask = (full_attention_mask < 0.5).bool()
full_attention_mask.unsqueeze_(1)
return full_attention_mask
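    # Note: get_masks returns a boolean mask where True marks positions that must NOT be attended to
    # (the `< 0.5` comparison inverts the 0/1 attend mask), with shape [batch, 1, seq, seq(+past)].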
def get_position_ids(self, input_ids, device):
batch_size, seq_length = input_ids.shape
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
return position_ids
def get_multimodal_position_ids(self, input_ids, device):
batch_size, seq_length = input_ids.shape
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
return position_ids
class Embedding(torch.nn.Module):
"""Language model embeddings."""
def __init__(self, config: GLM4VConfig, device=None):
super().__init__()
self.hidden_size = config.hidden_size
# Word embeddings (parallel).
self.word_embeddings = torch.nn.Embedding(
config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
)
self.fp32_residual_connection = config.fp32_residual_connection
def forward(self, input_ids):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
embeddings = words_embeddings
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
return embeddings
def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
if images_list is None or len(images_list) == 0:
return True
for image_list in images_list:
if image_list is None:
raise ValueError("Image list contains some invalid contents (probably None)!")
return False
class GLM4VModel(GLM4VPreTrainedModel):
def __init__(self, config: GLM4VConfig, device=None, empty_init=True):
super().__init__(config)
if empty_init:
init_method = skip_init
else:
init_method = default_init
init_kwargs = {}
if device is not None:
init_kwargs["device"] = device
self.embedding = init_method(Embedding, config, **init_kwargs)
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
# Rotary positional embeddings
self.seq_length = config.seq_length
rotary_dim = (
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
self.rotary_pos_emb = RotaryEmbedding(
rotary_dim // 2,
rope_ratio=config.rope_ratio,
original_impl=config.original_rope,
device=device,
dtype=config.torch_dtype,
)
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
self.output_layer = init_method(
torch.nn.Linear,
config.hidden_size,
config.padded_vocab_size,
bias=False,
dtype=config.torch_dtype,
**init_kwargs,
)
self.pre_seq_len = config.pre_seq_len if config.pre_seq_len is not None else 0
self.prefix_projection = config.prefix_projection
if self.pre_seq_len > 0:
for param in self.parameters():
param.requires_grad = False
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = PrefixEncoder(config)
self.dropout = torch.nn.Dropout(0.1)
self.vision = EVA2CLIPModel(config)
if hasattr(config, "vision_config"):
self.image_size: int = self.config.vision_config["image_size"]
self.patch_size: int = self.config.vision_config["patch_size"]
self.num_patches = (self.image_size // self.patch_size // 2) ** 2
def get_input_embeddings(self):
return self.embedding.word_embeddings
def set_input_embeddings(self, value):
self.embedding.word_embeddings = value
def get_prompt(self, batch_size, device, dtype=torch.half):
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
past_key_values = past_key_values.view(
batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels
)
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
return past_key_values
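    # get_prompt returns a tuple of num_layers tensors, each of shape
    # [2, pre_seq_len, batch, multi_query_group_num, kv_channels]; SelfAttention.forward later
    # splits each entry into the prefix key/value and concatenates them along the sequence dim.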
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length, pre_seq_len
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length + pre_seq_len,
)
return combined_attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
if pre_seq_len > 0:
pre_seq_mask = torch.zeros(
[input_shape[0], 1, 1, pre_seq_len],
dtype=expanded_attn_mask.dtype,
device=expanded_attn_mask.device,
)
expanded_attn_mask = torch.cat([pre_seq_mask, expanded_attn_mask], dim=-1)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.encoder.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.encoder.reorder_kv_cache(beam_idx)
def forward(
self,
input_ids: torch.LongTensor = None,
images: torch.Tensor = None,
images_idx: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""take care of image_encode, position_ids and (attention_mask = None is fine)"""
batch_size, seq_length = input_ids.shape
        # Generation mode with past_key_values: the image features have already been merged.
        if past_key_values is None:
            # inputs_embeds is not allowed here because the image features still need to be merged into the embeddings.
            assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
if not is_empty(images): # multi-modality
assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
# Please make sure to provide position_ids in inputs for Gaudi.
assert position_ids is not None
assert images_idx is not None
inputs_embeds = self.embedding(input_ids)
images = images.to(dtype=inputs_embeds.dtype)
images_features = self.vision(images)
if self.training and self.embedding.word_embeddings.weight.requires_grad:
inputs_embeds_list = []
for i in range(batch_size):
input_embeds_bs = torch.index_copy(inputs_embeds[i], 0, images_idx[i], images_features[i])
inputs_embeds_list.append(input_embeds_bs.unsqueeze(0))
inputs_embeds = torch.cat(inputs_embeds_list, dim=0)
else:
with torch.no_grad():
for i in range(batch_size):
inputs_embeds[i].index_copy_(0, images_idx[i], images_features[i])
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
prefix_encoders = None
if self.pre_seq_len > 0:
if token_idx is not None:
token_idx = token_idx + self.pre_seq_len
if past_key_values is None:
prefix_encoders = self.get_prompt(
batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
)
# Ptuning for multi-modality? This path is not verified for Gaudi
"""
if (attention_mask is not None and not attention_mask.all()) or (prefix_encoders and seq_length != 1):
if self.training:
for i in range(batch_size):
attention_mask[i].index_copy_(0, images_idx[i], torch.ones(self.num_patches, device=attention_mask.device, type=torch.int32))
input_ids[i].index_copy_(0, images_idx[i], input_ids[i, -1].repeat(self.num_patches) )
inputs_embeds = self.embedding(input_ids)
"""
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
if reuse_cache:
past_key_values_length = past_key_values[0][0][2]
else:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None and images is None:
position_ids = torch.arange(
past_key_values_length, seq_length_with_past, dtype=torch.long, device=inputs_embeds.device
)
position_ids = position_ids.unsqueeze(0)
if position_ids.size(-1) < seq_length:
position_ids = F.pad(position_ids, (0, seq_length - position_ids.size(-1)), "constant", 0)
cache_position = None
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
rotary_pos_emb = rotary_pos_emb[position_ids]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
if self.pre_seq_len > 0:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, self.pre_seq_len
)
else:
attention_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask,
input_ids.shape if input_ids is not None else (batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
# Run encoder.
hidden_states, next_cache, all_hidden_states, all_self_attns = self.encoder(
inputs_embeds,
attention_mask,
prefix_encoders,
rotary_pos_emb,
past_key_values,
use_cache=use_cache,
cache_position=cache_position,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _history_to_prompt(history, query):
prompt = ""
flag = False
for i, (old_query, response) in enumerate(history):
prompt += ("<|user|>" if flag else "") + old_query + "<|assistant|>" + response + "<|endoftext|>"
flag = True
prompt += "{}{}<|assistant|>".format("<|user|>" if flag else "", query)
return prompt
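# Illustrative example for _history_to_prompt (assumed inputs):
#   _history_to_prompt([("Hi", "Hello!")], "How are you?")
#   -> "Hi<|assistant|>Hello!<|endoftext|><|user|>How are you?<|assistant|>"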
class GLM4VForConditionalGeneration(GLM4VPreTrainedModel, GenerationMixin):
def __init__(self, config: GLM4VConfig, empty_init=True, device=None):
super().__init__(config)
self.max_sequence_length = config.max_length
self.transformer = GLM4VModel(config, empty_init=empty_init, device=device)
self.config = config
if hasattr(config, "vision_config"):
self.image_size: int = self.config.vision_config["image_size"]
self.patch_size: int = self.config.vision_config["patch_size"]
self.num_patches = (self.image_size // self.patch_size // 2) ** 2
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.transformer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
self.kv_cache_len = max_seq_len
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.transformer.reorder_kv_cache(beam_idx)
def adjust_multimodal_inputs(self, inputs):
config = self.config
assert hasattr(config, "vision_config")
image_size: int = config.vision_config["image_size"]
patch_size: int = config.vision_config["patch_size"]
num_patches = (image_size // patch_size // 2) ** 2
input_ids = inputs["input_ids"]
position_ids = inputs["position_ids"]
attention_mask = inputs["attention_mask"]
images_idx = []
batch_size = len(input_ids)
for i in range(batch_size):
boi_token_pos, eoi_token_pos = (
input_ids[i].index(config.boi_token_id),
input_ids[i].index(config.eoi_token_id),
)
assert eoi_token_pos - boi_token_pos == 2
new_input_ids = (
input_ids[i][: boi_token_pos + 1] + [input_ids[i][-1]] * num_patches + input_ids[i][eoi_token_pos:]
)
new_position_ids = (
position_ids[i][: boi_token_pos + 1]
+ [position_ids[i][boi_token_pos + 1]] * num_patches
+ position_ids[i][eoi_token_pos:]
)
new_attention_mask = (
attention_mask[i][: boi_token_pos + 1] + [1] * num_patches + attention_mask[i][eoi_token_pos:]
)
new_image_idx = list(range(boi_token_pos, boi_token_pos + num_patches + 2))
input_ids[i] = new_input_ids
position_ids[i] = new_position_ids
attention_mask[i] = new_attention_mask
images_idx.append(new_image_idx)
inputs.data["images_idx"] = images_idx
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
images: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: torch.Tensor = None,
inputs_embeds=None,
token_idx=None,
**kwargs,
) -> dict:
reuse_cache = kwargs.get("reuse_cache")
if past_key_values:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
else:
input_ids = input_ids[:, -1:]
elif reuse_cache and token_idx is not None:
            # With reuse_cache, the KV cache is pre-allocated, so for the first token we can slice the inputs up to token_idx for the forward pass
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
assert position_ids is not None
images_idx = kwargs.get("images_idx")
if past_key_values:
position_ids = position_ids[..., -1:] + token_idx - position_ids.size(-1) - 1
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"images": images,
"images_idx": images_idx,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"trim_logits": kwargs.get("trim_logits"),
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
}
)
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
images: torch.Tensor = None,
images_idx: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
token_idx: Optional[torch.Tensor] = None,
trim_logits: Optional[bool] = False,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids=input_ids,
images=images,
images_idx=images_idx,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
)
hidden_states = outputs[0].transpose(0, 1).contiguous()
_, seq_len, _ = hidden_states.shape
if seq_len > 1 and trim_logits and not self.training:
if token_idx is not None:
hidden_states = hidden_states.index_select(1, token_idx - 1)
else:
hidden_states = hidden_states[:, -1, :]
lm_logits = self.transformer.output_layer(hidden_states).float()
loss = None
if labels is not None:
# This part should be done before sending into the model for Gaudi
"""
new_labels = []
for i in range(len(input_ids)):
input_id = input_ids[i].tolist()
boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
self.config.eoi_token_id)
assert eoi_token_pos - boi_token_pos == 2
new_labels.append(torch.cat(
(
labels[i, :boi_token_pos + 1],
torch.tensor([-100]).to(labels.device).to(labels.dtype).repeat(1600),
labels[i, eoi_token_pos:])))
labels = torch.stack(new_labels, dim=0)
"""
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@staticmethod
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
return tuple(
(
layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
)
for layer_past in past
)
class GLM4VForSequenceClassification(GLM4VPreTrainedModel):
def __init__(self, config: GLM4VConfig, empty_init=True, device=None):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = GLM4VModel(config, empty_init=empty_init, device=device)
self.classifier_head = torch.nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
if config.classifier_dropout is not None:
self.dropout = torch.nn.Dropout(config.classifier_dropout)
else:
self.dropout = None
self.config = config
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids=input_ids,
images=None,
images_idx=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
)
hidden_states = outputs[0].transpose(0, 1).contiguous()
pooled_hidden_states = hidden_states[-1]
if self.dropout is not None:
pooled_hidden_states = self.dropout(pooled_hidden_states)
logits = self.classifier_head(pooled_hidden_states)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze().float(), labels.squeeze())
else:
loss = loss_fct(logits.float(), labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)