in modules/SwissArmyTransformer/sat/model/transformer.py
def forward(self, input_ids, position_ids, attention_mask, *,
            output_hidden_states=False, **kw_args):
    # sanity check
    assert len(input_ids.shape) >= 2
    batch_size, query_length = input_ids.shape[:2]

    if attention_mask is None:
        # Definition: None means full attention
        attention_mask = torch.ones(1, 1, device=input_ids.device)
    elif isinstance(attention_mask, int) and (attention_mask < 0):
        # Definition: -1 means lower triangular attention mask
        attention_mask = torch.ones(query_length, query_length,
                                    device=input_ids.device).tril()
    attention_mask = attention_mask.type_as(
        next(self.parameters())
    )
    assert len(attention_mask.shape) == 2 or \
        len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
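    # accepted mask layouts: a 2-D (query, key) mask that broadcasts over batch and
    # heads, or a 4-D mask (typically (batch, 1, query, key)) with a singleton head dim.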
    # the initial output_cross_layer may already be populated by word/position_embedding_forward
    output_cross_layer = {}

    # embedding part
    if 'word_embedding_forward' in self.hooks:
        hidden_states = self.hooks['word_embedding_forward'](input_ids, output_cross_layer=output_cross_layer, **kw_args)
    else:  # default
        hidden_states = HOOKS_DEFAULT['word_embedding_forward'](self, input_ids, output_cross_layer=output_cross_layer, **kw_args)
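    # each stage first checks self.hooks (usually registered by model mixins) and
    # falls back to the baseline implementation in HOOKS_DEFAULT.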
    # handle position embedding
    if 'position_embedding_forward' in self.hooks:
        position_embeddings = self.hooks['position_embedding_forward'](position_ids, output_cross_layer=output_cross_layer, **kw_args)
    else:
        assert len(position_ids.shape) <= 2
        assert position_ids.shape[-1] == hidden_states.shape[1], (position_ids.shape, hidden_states.shape)
        position_embeddings = HOOKS_DEFAULT['position_embedding_forward'](self, position_ids, output_cross_layer=output_cross_layer, **kw_args)
    if position_embeddings is not None:
        hidden_states = hidden_states + position_embeddings
    hidden_states = self.embedding_dropout(hidden_states)
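    # hidden_states is now (batch, seq_len, hidden_size); the branch below chooses
    # between a checkpointed and a plain pass over self.layers.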
    output_per_layers = []
    if self.checkpoint_activations:
        # define custom_forward for checkpointing
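        # checkpoint() re-runs the wrapped function during backward and only tracks
        # positional tensor inputs/outputs, so dict-valued kw_args and cross-layer
        # outputs are flattened into a tuple plus index maps and rebuilt inside
        # custom_forward.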
        def custom(start, end, kw_args_index, cross_layer_index):
            def custom_forward(*inputs):
                layers_ = self.layers[start:end]
                x_, mask = inputs[0], inputs[1]
                # recover kw_args and output_cross_layer
                flat_inputs = inputs[2:]
                kw_args, output_cross_layer = {}, {}
                for k, idx in kw_args_index.items():
                    kw_args[k] = flat_inputs[idx]
                for k, idx in cross_layer_index.items():
                    output_cross_layer[k] = flat_inputs[idx]
                # -----------------

                output_per_layers_part = []
                for i, layer in enumerate(layers_):
                    output_this_layer_obj, output_cross_layer_obj = {}, {}
                    if 'layer_forward' in self.hooks:
                        layer_ret = self.hooks['layer_forward'](
                            x_, mask, layer_id=layer.layer_id,
                            **kw_args, position_ids=position_ids, **output_cross_layer,
                            output_this_layer=output_this_layer_obj,
                            output_cross_layer=output_cross_layer_obj
                        )
                    else:
                        layer_ret = layer(
                            x_, mask, layer_id=layer.layer_id,
                            **kw_args, position_ids=position_ids, **output_cross_layer,
                            output_this_layer=output_this_layer_obj,
                            output_cross_layer=output_cross_layer_obj
                        )
                    if isinstance(layer_ret, tuple):
                        layer_ret = layer_ret[0]  # for legacy API
                    x_, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
                    if output_hidden_states:
                        output_this_layer['hidden_states'] = x_
                    output_per_layers_part.append(output_this_layer)
                # flatten the keyword outputs so they can pass through checkpoint
                # and be re-aggregated by the caller
                flat_outputs = []
                for output_this_layer in output_per_layers_part:
                    for k in output_this_layer:
                        # TODO add warning for depth>=2 grad tensors
                        flat_outputs.append(output_this_layer[k])
                        output_this_layer[k] = len(flat_outputs) - 1
                for k in output_cross_layer:
                    flat_outputs.append(output_cross_layer[k])
                    output_cross_layer[k] = len(flat_outputs) - 1
                # --------------------
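                # each dict value is now an index into flat_outputs, so the actual
                # tensors travel through checkpoint as positional outputs and are
                # restored by the caller after the chunk returns.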
                return (x_, output_per_layers_part, output_cross_layer, *flat_outputs)
            return custom_forward
        # prevent losing requires_grad in checkpointing.
        # To save memory when only finetuning the final layers, don't use checkpointing.
        if self.training:
            hidden_states.requires_grad_(True)
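        # checkpoint() needs at least one input that requires grad; otherwise the
        # recomputation during backward is skipped and no gradients reach the chunk.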
        l, num_layers = 0, len(self.layers)
        chunk_length = self.checkpoint_num_layers
        output_this_layer = []
        while l < num_layers:
            args = [hidden_states, attention_mask]
            # flatten kw_args and output_cross_layer
            flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
            for k, v in kw_args.items():
                flat_inputs.append(v)
                kw_args_index[k] = len(flat_inputs) - 1
            for k, v in output_cross_layer.items():
                flat_inputs.append(v)
                cross_layer_index[k] = len(flat_inputs) - 1
            # --------------------
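            # the last self.checkpoint_skip_layers layers (rounded to whole chunks of
            # checkpoint_num_layers) run without checkpointing, keeping their activations
            # instead of recomputing them in backward.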
            if l + self.checkpoint_skip_layers >= num_layers:
                # no checkpointing
                hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
                    custom(l, l + chunk_length, kw_args_index, cross_layer_index)(*args, *flat_inputs)
            else:
                hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
                    checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)

            # recover output_per_layers_part, output_cross_layer
            for output_this_layer in output_per_layers_part:
                for k in output_this_layer:
                    output_this_layer[k] = flat_outputs[output_this_layer[k]]
            for k in output_cross_layer:
                output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
            # --------------------

            output_per_layers.extend(output_per_layers_part)
            l += chunk_length
    else:
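        # plain (non-checkpointed) path: run each layer once and collect its outputs
        # directly, keeping all intermediate activations in memory.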
        output_this_layer = []
        for i, layer in enumerate(self.layers):
            args = [hidden_states, attention_mask]

            output_this_layer_obj, output_cross_layer_obj = {}, {}

            if 'layer_forward' in self.hooks:  # customized layer_forward
                layer_ret = self.hooks['layer_forward'](*args,
                    layer_id=torch.tensor(i),
                    **kw_args,
                    position_ids=position_ids,
                    **output_cross_layer,
                    output_this_layer=output_this_layer_obj, output_cross_layer=output_cross_layer_obj
                )
            else:
                layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, position_ids=position_ids, **output_cross_layer,
                                  output_this_layer=output_this_layer_obj, output_cross_layer=output_cross_layer_obj)
            if isinstance(layer_ret, tuple):
                layer_ret = layer_ret[0]  # for legacy API
            hidden_states, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
            if output_hidden_states:
                output_this_layer['hidden_states'] = hidden_states
            output_per_layers.append(output_this_layer)

    # Final layer norm.
    if self.use_final_layernorm:
        logits = self.final_layernorm(hidden_states)
    else:
        logits = hidden_states
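    # copy_to_model_parallel_region is an identity in the forward pass and all-reduces
    # gradients over the model-parallel group in backward, so the model-parallel output
    # head in final_forward backpropagates correctly into the replicated hidden states.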
    logits = copy_to_model_parallel_region(logits)
    if 'final_forward' in self.hooks:
        logits_parallel = self.hooks['final_forward'](logits, **kw_args, parallel_output=self.parallel_output)
    else:
        logits_parallel = HOOKS_DEFAULT['final_forward'](self, logits, **kw_args, parallel_output=self.parallel_output)

    outputs = [logits_parallel]
    outputs.extend(output_per_layers)
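    # outputs[0] is the (possibly vocab-parallel, depending on parallel_output) logits;
    # outputs[1:] contains one dict of per-layer outputs per transformer layer,
    # e.g. 'hidden_states' when output_hidden_states=True.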
    return outputs