# coding=utf-8
# rewritten, Copyright (c) 2021, Ming Ding. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import copy
import torch
import torch.nn.functional as F
from sat import mpu
from sat.mpu import (
    get_model_parallel_world_size,
    ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding,
    gather_from_model_parallel_region, copy_to_model_parallel_region,
    checkpoint,
)
from sat.mpu.utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
from sat.ops.layernorm import LayerNorm
from sat.transformer_defaults import HOOKS_DEFAULT, standard_attention, split_tensor_along_last_dim


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer for Transformer."""
def __init__(self, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, qkv_bias=False, num_multi_query_heads=0, row_parallel_linear_final_bias=True,
hooks={}, transformer_pointer=None, params_dtype=torch.float, skip_init=False, device=torch.device('cpu')):
super(SelfAttention, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
self.layer_id = layer_id
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_multi_query_heads = num_multi_query_heads
if hidden_size_per_attention_head is None:
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
else:
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
self.num_multi_query_heads_per_partition = divide(num_multi_query_heads, world_size)
self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
# Strided linear layer.
if num_multi_query_heads == 0:
qkv_size = 3 * self.inner_hidden_size
self.stride = 3
else: # multi-query
qkv_size = self.inner_hidden_size + self.hidden_size_per_attention_head * self.num_multi_query_heads * 2
self.stride = [self.num_attention_heads_per_partition, self.num_multi_query_heads_per_partition, self.num_multi_query_heads_per_partition]
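        # Worked sizing example (illustrative values, not defaults): with
        # hidden_size=4096, num_attention_heads=32, head dim 4096 // 32 = 128:
        #   multi-head  (num_multi_query_heads=0): qkv_size = 3 * 4096 = 12288, stride = 3
        #   multi-query (num_multi_query_heads=2): qkv_size = 4096 + 128 * 2 * 2 = 4608
        # i.e. queries keep one projection per head, while keys/values are
        # shared across a handful of multi-query heads.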
self.query_key_value = ColumnParallelLinear(
hidden_size,
qkv_size,
stride=self.stride,
gather_output=False,
init_method=init_method,
bias=bias or qkv_bias,
params_dtype=params_dtype,
module=self,
name="query_key_value",
skip_init=skip_init,
device=device
)
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
self.dense = RowParallelLinear(
self.inner_hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name="dense",
skip_init=skip_init,
device=device,
final_bias=row_parallel_linear_final_bias
)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
        # Keep a back-reference to the owning transformer without registering
        # it as a submodule (nn.Module.__setattr__ would pick it up and
        # recursively include it in parameters() and state_dict()).
        object.__setattr__(self, 'transformer', transformer_pointer)
        assert transformer_pointer is not None

    def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(-1, # flexible for multi-query
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
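
    # Shape note for _transpose_for_scores (hypothetical sizes): a [2, 10, 4096]
    # tensor with hidden_size_per_attention_head=128 becomes [2, 10, 32, 128]
    # after the view and [2, 32, 10, 128] after the permute. The -1 in the view
    # lets the same code serve full multi-head queries and the narrower
    # multi-query key/value projections.
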
def forward(self, hidden_states, mask, *args, **kw_args):
if 'attention_forward' in self.hooks:
return self.hooks['attention_forward'](hidden_states, mask, **kw_args)
else:
return HOOKS_DEFAULT['attention_forward'](self, hidden_states, mask, **kw_args)

    def repartition(self):
world_size = get_model_parallel_world_size()
self.num_attention_heads_per_partition = divide(self.num_attention_heads, world_size)
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
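

# Minimal construction sketch for SelfAttention (illustrative, assumes sat.mpu
# model parallelism is already initialized; `parent_transformer` stands in for
# the owning BaseTransformer, which must be passed and non-None):
#
#   attn = SelfAttention(
#       hidden_size=1024, num_attention_heads=16,
#       attention_dropout_prob=0.1, output_dropout_prob=0.1,
#       init_method=unscaled_init_method(0.02), layer_id=0,
#       transformer_pointer=parent_transformer,
#   )
#   out = attn(hidden_states, mask)  # dispatches to HOOKS_DEFAULT['attention_forward']

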
class CrossAttention(torch.nn.Module):
"""Parallel cross-attention layer for Transformer"""
def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, init_method,
layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, cross_num_multi_query_heads=0, row_parallel_linear_final_bias=True, hooks={},
cross_attn_hidden_size=None, transformer_pointer=None, params_dtype=torch.float, skip_init=False, device=torch.device('cpu')):
super().__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
self.layer_id = layer_id
self.num_attention_heads = num_attention_heads
self.hidden_size = hidden_size
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
if hidden_size_per_attention_head is None:
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
else:
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
self.cross_num_multi_query_heads = cross_num_multi_query_heads
# Strided linear layer.
if cross_num_multi_query_heads == 0:
kv_size = 2 * self.inner_hidden_size
else: # multi-query
kv_size = self.hidden_size_per_attention_head * self.cross_num_multi_query_heads * 2
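        # e.g. (illustrative) cross_num_multi_query_heads=2 with head dim 128
        # gives kv_size = 128 * 2 * 2 = 512, versus 2 * inner_hidden_size in
        # the standard multi-head case.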
        self.query = ColumnParallelLinear(
            hidden_size,
            self.inner_hidden_size,
            gather_output=False,
            init_method=init_method,
            bias=bias,
            params_dtype=params_dtype,
            module=self,
            name="query",
            skip_init=skip_init,
            device=device
        )
        if cross_attn_hidden_size is None:
            cross_attn_hidden_size = hidden_size
        self.cross_attn_hidden_size = cross_attn_hidden_size
        self.key_value = ColumnParallelLinear(
            cross_attn_hidden_size,
            kv_size,
            stride=2,
            gather_output=False,
            init_method=init_method,
            bias=bias,
            params_dtype=params_dtype,
            module=self,
            name="key_value",
            skip_init=skip_init,
            device=device
        )
        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different numbers of parallel partitions, but
        # on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
        self.dense = RowParallelLinear(
            self.inner_hidden_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            bias=bias,
            params_dtype=params_dtype,
            module=self,
            name="dense",
            skip_init=skip_init,
            device=device,
            final_bias=row_parallel_linear_final_bias
        )
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
object.__setattr__(self, 'transformer', transformer_pointer)
assert transformer_pointer is not None

    def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(-1, # flexible for multi-query
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)

    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
# hidden_states: [b, s, h]
if 'cross_attention_forward' in self.hooks:
return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
else:
return HOOKS_DEFAULT['cross_attention_forward'](self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args)

    def repartition(self):
world_size = get_model_parallel_world_size()
self.num_attention_heads_per_partition = divide(self.num_attention_heads, world_size)
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition


class MLP(torch.nn.Module):
    """Parallel MLP (optionally gated and/or multi-expert) for Transformer."""
def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
output_layer_init_method=None, layer_id=None, row_parallel_linear_final_bias=True, hooks={}, bias=True, activation_func=gelu, transformer_pointer=None, is_gated_mlp=False, num_experts=1,
params_dtype=torch.float, skip_init=False, device=torch.device('cpu')):
super(MLP, self).__init__()
self.layer_id = layer_id
self.activation_func = activation_func
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
# Project to 4h.
self.hidden_size = hidden_size
if inner_hidden_size is None:
inner_hidden_size = 4 * hidden_size
self.inner_hidden_size = inner_hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name="dense_h_to_4h",
skip_init=skip_init,
device=device
)
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
self.inner_hidden_size,
self.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name="dense_4h_to_h",
skip_init=skip_init,
device=device,
final_bias=row_parallel_linear_final_bias
)
self.is_gated_mlp = is_gated_mlp
if is_gated_mlp:
self.dense_h_to_4h_gate = ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=False,
params_dtype=params_dtype,
module=self,
name="dense_h_to_4h_gate",
skip_init=skip_init,
device=device
)
self.num_experts = num_experts
for i in range(1, num_experts):
self.register_module(f"dense_h_to_4h_{i}", ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name=f"dense_h_to_4h_{i}",
skip_init=skip_init,
device=device
))
# Project back to h.
self.register_module(f"dense_4h_to_h_{i}", RowParallelLinear(
self.inner_hidden_size,
self.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name=f"dense_4h_to_h_{i}",
skip_init=skip_init,
device=device,
final_bias=row_parallel_linear_final_bias
))
if is_gated_mlp:
self.register_module(f"dense_h_to_4h_gate_{i}", ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=False,
params_dtype=params_dtype,
module=self,
name=f"dense_h_to_4h_gate_{i}",
skip_init=skip_init,
device=device
))
self.dropout = torch.nn.Dropout(output_dropout_prob)
object.__setattr__(self, 'transformer', transformer_pointer)
assert transformer_pointer is not None

    def forward(self, hidden_states, **kw_args):
if 'mlp_forward' in self.hooks:
output = self.hooks['mlp_forward'](hidden_states, **kw_args)
else:
output = HOOKS_DEFAULT['mlp_forward'](self, hidden_states, **kw_args)
if self.training:
output = self.dropout(output)
return output
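
# For reference, the default 'mlp_forward' hook (see sat.transformer_defaults)
# computes roughly the following for the single-expert case; the gated branch
# is sketched under the assumption that the activated gate multiplies the
# up-projection, SwiGLU-style:
#
#   h = mlp.dense_h_to_4h(hidden_states)
#   if mlp.is_gated_mlp:
#       h = mlp.activation_func(mlp.dense_h_to_4h_gate(hidden_states)) * h
#   else:
#       h = mlp.activation_func(h)
#   output = mlp.dense_4h_to_h(h)
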
class BaseTransformerLayer(torch.nn.Module):
    """A single Transformer layer: self-attention, optional cross-attention
    (decoder layers only), and an MLP, each with its layernorm."""
def __init__(
self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
init_method,
layer_id,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
cross_hidden_size_per_attention_head=None,
output_layer_init_method=None,
layernorm_order='pre',
layernorm=LayerNorm,
is_decoder=False,
cross_attn_hidden_size=None,
use_bias=True,
use_qkv_bias=False,
num_multi_query_heads=0,
cross_num_multi_query_heads=0,
row_parallel_linear_final_bias=True,
drop_path=0,
activation_func=gelu,
is_gated_mlp=False,
num_experts=1,
hooks={},
transformer_pointer=None,
params_dtype=torch.float,
skip_init=False,
device=torch.device('cpu')
):
super(BaseTransformerLayer, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.layer_id = layer_id
self.is_decoder = is_decoder[layer_id] if type(is_decoder) is list else is_decoder
self.layernorm_order = layernorm_order
self.drop_path = drop_path
self.hooks = hooks
object.__setattr__(self, 'transformer', transformer_pointer)
assert transformer_pointer is not None
# Layernorm on the input data.
self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# Self attention.
self.attention = SelfAttention(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
init_method,
layer_id,
hidden_size_per_attention_head=hidden_size_per_attention_head,
output_layer_init_method=output_layer_init_method,
bias=use_bias,
qkv_bias=use_qkv_bias,
num_multi_query_heads=num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
hooks=hooks,
transformer_pointer=transformer_pointer,
params_dtype=params_dtype,
skip_init=skip_init,
device=device
)
        # Layernorm on the attention output.
self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
if self.layernorm_order == 'sandwich':
self.third_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
self.fourth_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# Cross attention.
if self.is_decoder:
self.cross_attention = CrossAttention(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
init_method,
layer_id,
hidden_size_per_attention_head=cross_hidden_size_per_attention_head,
output_layer_init_method=output_layer_init_method,
cross_attn_hidden_size=cross_attn_hidden_size,
bias=use_bias,
cross_num_multi_query_heads=cross_num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
hooks=hooks,
transformer_pointer=transformer_pointer,
                params_dtype=params_dtype,
                skip_init=skip_init,
                device=device
)
self.post_cross_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# MLP
self.mlp = MLP(
hidden_size,
output_dropout_prob,
init_method,
inner_hidden_size=inner_hidden_size,
output_layer_init_method=output_layer_init_method,
bias=use_bias,
layer_id=layer_id,
activation_func=activation_func,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
hooks=hooks,
transformer_pointer=transformer_pointer,
is_gated_mlp=is_gated_mlp,
num_experts=num_experts,
params_dtype=params_dtype,
skip_init=skip_init,
device=device
)

    def forward(self, hidden_states, mask, *args, **kw_args):
return HOOKS_DEFAULT['layer_forward'](self, hidden_states, mask, *args, **kw_args)


class BaseTransformer(torch.nn.Module):
    """Embeddings, a stack of BaseTransformerLayer modules, and an optional
    final layernorm; behavior is customizable via the hooks dict."""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
max_sequence_length,
embedding_dropout_prob=0,
attention_dropout_prob=0,
output_dropout_prob=0,
drop_path=0,
checkpoint_activations=False,
checkpoint_num_layers=1,
checkpoint_skip_layers=0,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
cross_hidden_size_per_attention_head=None,
layernorm_order='pre',
parallel_output=False,
is_decoder=False,
cross_attn_hidden_size=None,
use_bias=True,
use_qkv_bias=False,
num_multi_query_heads=0,
cross_num_multi_query_heads=0,
row_parallel_linear_final_bias=True,
activation_func=gelu,
is_gated_mlp=False,
is_rotary_emb=False,
num_experts=1,
layernorm=LayerNorm,
init_method=None,
use_final_layernorm=True,
hooks={},
params_dtype=torch.float,
skip_init=False,
device=torch.device('cpu')
):
super(BaseTransformer, self).__init__()
# recording parameters
self.hidden_size = hidden_size
self.inner_hidden_size = inner_hidden_size
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.cross_hidden_size_per_attention_head = cross_hidden_size_per_attention_head
self.is_decoder = is_decoder
self.cross_attn_hidden_size = cross_attn_hidden_size
self.cross_num_multi_query_heads = cross_num_multi_query_heads
if not is_decoder and cross_attn_hidden_size is not None:
print('warning: cross_attn_hidden_size is set but is_decoder is False')
self.use_bias = use_bias
self.use_qkv_bias = use_qkv_bias
self.num_multi_query_heads = num_multi_query_heads
self.is_gated_mlp = is_gated_mlp
self.is_rotary_emb = is_rotary_emb
self.num_experts = num_experts
self.use_final_layernorm = use_final_layernorm
self.layernorm_epsilon = layernorm_epsilon
self.parallel_output = parallel_output
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
self.checkpoint_skip_layers = checkpoint_skip_layers
        assert checkpoint_skip_layers <= num_layers - checkpoint_num_layers, \
            'checkpoint_skip_layers too large. Please consider removing checkpoint_activations.'
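        # E.g. (illustrative) num_layers=24, checkpoint_num_layers=1,
        # checkpoint_skip_layers=2: the last two layers run without activation
        # checkpointing, which is cheaper when only late layers are finetuned.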
self.max_sequence_length = max_sequence_length
self.layernorm_order = layernorm_order
self.row_parallel_linear_final_bias = row_parallel_linear_final_bias
self.hooks = copy.copy(hooks) # hooks will be updated each forward
object.__setattr__(self, 'transformer', self) # to give the default hooks the same api as outer hooks
# create embedding parameters
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
        if vocab_size < 1000:  # tiny vocabulary: plain embedding, not worth vocab parallelism
self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size, dtype=params_dtype, device=device)
torch.nn.init.normal_(self.word_embeddings.weight, mean=0.0, std=init_method_std)
else:
self.word_embeddings = VocabParallelEmbedding(
num_embeddings=vocab_size, embedding_dim=hidden_size,
params_dtype=params_dtype, skip_init=skip_init, device=device)
if self.is_rotary_emb:
from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding
self.position_embeddings = FastRotaryEmbedding(hidden_size // num_attention_heads)
else:
self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
# create all layers
if init_method is None:
self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
self.init_method = unscaled_init_method(init_method_std)
else:
self.output_layer_init_method = init_method
self.init_method = init_method

        def get_layer(layer_id):
return BaseTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
self.init_method,
layer_id,
inner_hidden_size=inner_hidden_size,
hidden_size_per_attention_head=hidden_size_per_attention_head,
cross_hidden_size_per_attention_head=cross_hidden_size_per_attention_head,
output_layer_init_method=self.output_layer_init_method,
is_decoder=self.is_decoder,
cross_attn_hidden_size=cross_attn_hidden_size,
layernorm_order=layernorm_order,
layernorm=layernorm,
use_bias=use_bias,
use_qkv_bias=use_qkv_bias,
num_multi_query_heads=num_multi_query_heads,
cross_num_multi_query_heads=cross_num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
drop_path=drop_path,
activation_func=activation_func,
is_gated_mlp=is_gated_mlp,
num_experts=num_experts,
hooks=self.hooks,
transformer_pointer=self,
params_dtype=params_dtype,
skip_init=skip_init,
device=device
)
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(num_layers)])
# Final layer norm before output.
if use_final_layernorm:
self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, input_ids, position_ids, attention_mask, *,
output_hidden_states=False, **kw_args):
# sanity check
assert len(input_ids.shape) >= 2
batch_size, query_length = input_ids.shape[:2]
if attention_mask is None:
# Definition: None means full attention
attention_mask = torch.ones(1, 1, device=input_ids.device)
elif isinstance(attention_mask, int) and (attention_mask < 0):
# Definition: -1 means lower triangular attention mask
attention_mask = torch.ones(query_length, query_length,
device=input_ids.device).tril()
attention_mask = attention_mask.type_as(
next(self.parameters())
)
        assert len(attention_mask.shape) == 2 or \
            (len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1)
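        # Accepted mask forms, concretely: None becomes a broadcastable [1, 1]
        # all-ones tensor (full attention); a negative int builds an [s, s]
        # lower-triangular causal mask; a caller-supplied mask must be 2D or
        # 4D with a singleton head dimension, e.g. torch.ones(b, 1, s, s).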
# initial output_cross_layer might be generated by word/position_embedding_forward
output_cross_layer = {}
# embedding part
if 'word_embedding_forward' in self.hooks:
hidden_states = self.hooks['word_embedding_forward'](input_ids, output_cross_layer=output_cross_layer, **kw_args)
else: # default
hidden_states = HOOKS_DEFAULT['word_embedding_forward'](self, input_ids, output_cross_layer=output_cross_layer,**kw_args)
# handle position embedding
if 'position_embedding_forward' in self.hooks:
position_embeddings = self.hooks['position_embedding_forward'](position_ids, output_cross_layer=output_cross_layer, **kw_args)
else:
assert len(position_ids.shape) <= 2
assert position_ids.shape[-1] == hidden_states.shape[1], (position_ids.shape, hidden_states.shape)
position_embeddings = HOOKS_DEFAULT['position_embedding_forward'](self, position_ids, output_cross_layer=output_cross_layer, **kw_args)
if position_embeddings is not None:
hidden_states = hidden_states + position_embeddings
hidden_states = self.embedding_dropout(hidden_states)
output_per_layers = []
if self.checkpoint_activations:
# define custom_forward for checkpointing
def custom(start, end, kw_args_index, cross_layer_index):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_, mask = inputs[0], inputs[1]
# recover kw_args and output_cross_layer
flat_inputs = inputs[2:]
kw_args, output_cross_layer = {}, {}
for k, idx in kw_args_index.items():
kw_args[k] = flat_inputs[idx]
for k, idx in cross_layer_index.items():
output_cross_layer[k] = flat_inputs[idx]
# -----------------
output_per_layers_part = []
for i, layer in enumerate(layers_):
output_this_layer_obj, output_cross_layer_obj = {}, {}
if 'layer_forward' in self.hooks:
layer_ret = self.hooks['layer_forward'](
x_, mask, layer_id=layer.layer_id,
**kw_args, position_ids=position_ids, **output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj
)
else:
layer_ret = layer(
x_, mask, layer_id=layer.layer_id,
**kw_args, position_ids=position_ids, **output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj
)
if isinstance(layer_ret, tuple):
layer_ret = layer_ret[0] # for legacy API
x_, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
if output_hidden_states:
output_this_layer['hidden_states'] = x_
output_per_layers_part.append(output_this_layer)
                    # flatten keyword outputs so they can be re-aggregated after checkpointing
flat_outputs = []
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
# TODO add warning for depth>=2 grad tensors
flat_outputs.append(output_this_layer[k])
output_this_layer[k] = len(flat_outputs) - 1
for k in output_cross_layer:
flat_outputs.append(output_cross_layer[k])
output_cross_layer[k] = len(flat_outputs) - 1
# --------------------
return (x_, output_per_layers_part, output_cross_layer, *flat_outputs)
return custom_forward
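
            # Index round-trip, on a tiny example: if a layer emits
            # {'attn_scores': t}, custom_forward appends t to flat_outputs and
            # stores its position, e.g. {'attn_scores': 0}; after checkpoint()
            # returns, the loop below restores {'attn_scores': flat_outputs[0]}.
            # Flattening keeps every tensor a positional output, as required
            # for gradients to flow through checkpointed segments.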
            # Prevent losing requires_grad during checkpointing.
            # To save memory when only finetuning the final layers, don't use checkpointing.
if self.training:
hidden_states.requires_grad_(True)
l, num_layers = 0, len(self.layers)
chunk_length = self.checkpoint_num_layers
output_this_layer = []
while l < num_layers:
args = [hidden_states, attention_mask]
# flatten kw_args and output_cross_layer
flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
for k, v in kw_args.items():
flat_inputs.append(v)
kw_args_index[k] = len(flat_inputs) - 1
for k, v in output_cross_layer.items():
flat_inputs.append(v)
cross_layer_index[k] = len(flat_inputs) - 1
# --------------------
if l + self.checkpoint_skip_layers >= num_layers:
# no checkpointing
hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
custom(l, l + chunk_length, kw_args_index, cross_layer_index)(*args, *flat_inputs)
else:
hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
# recover output_per_layers_part, output_cross_layer
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
output_this_layer[k] = flat_outputs[output_this_layer[k]]
for k in output_cross_layer:
output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
# --------------------
output_per_layers.extend(output_per_layers_part)
l += chunk_length
else:
output_this_layer = []
for i, layer in enumerate(self.layers):
args = [hidden_states, attention_mask]
output_this_layer_obj, output_cross_layer_obj = {}, {}
if 'layer_forward' in self.hooks: # customized layer_forward
layer_ret = self.hooks['layer_forward'](*args,
layer_id=torch.tensor(i),
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj, output_cross_layer=output_cross_layer_obj
)
else:
layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, position_ids=position_ids, **output_cross_layer,
output_this_layer=output_this_layer_obj, output_cross_layer=output_cross_layer_obj)
if isinstance(layer_ret, tuple):
layer_ret = layer_ret[0] # for legacy API
hidden_states, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
if output_hidden_states:
output_this_layer['hidden_states'] = hidden_states
output_per_layers.append(output_this_layer)
# Final layer norm.
if self.use_final_layernorm:
logits = self.final_layernorm(hidden_states)
else:
logits = hidden_states
logits = copy_to_model_parallel_region(logits)
if 'final_forward' in self.hooks:
logits_parallel = self.hooks['final_forward'](logits, **kw_args, parallel_output=self.parallel_output)
else:
logits_parallel = HOOKS_DEFAULT['final_forward'](self, logits, **kw_args, parallel_output=self.parallel_output)
outputs = [logits_parallel]
outputs.extend(output_per_layers)
return outputs
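

# End-to-end usage sketch (illustrative sizes; assumes torch.distributed and
# sat.mpu model parallelism are initialized before construction):
#
#   model = BaseTransformer(
#       num_layers=2, vocab_size=32000, hidden_size=256,
#       num_attention_heads=8, max_sequence_length=512,
#   )
#   input_ids = torch.randint(0, 32000, (1, 16))
#   position_ids = torch.arange(16).unsqueeze(0)
#   logits, *per_layer_outputs = model(input_ids, position_ids, attention_mask=-1)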