modules/SwissArmyTransformer/sat/model/encoder_decoder_model.py
# -*- encoding: utf-8 -*-
'''
@File : encoder_decoder_model.py
@Time : 2021/11/22 23:35:28
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import argparse
from .base_model import BaseModel, BaseMixin, load_checkpoint, get_model
from sat.mpu.mappings import copy_to_model_parallel_region
from sat import update_args_with_file
from sat.resources import auto_create

class EncoderFinalMixin(BaseMixin):
    def final_forward(self, logits, **kwargs):
        logits = copy_to_model_parallel_region(logits)
        return logits

class EncoderDecoderModel(torch.nn.Module):
    def __init__(self, args, encoder=None, decoder=None, tie_word_embeddings=True, **kwargs):
        super(EncoderDecoderModel, self).__init__()
        if encoder is not None:
            assert isinstance(encoder, BaseModel)
            self.encoder = encoder
        else:
            self.encoder = BaseModel(args, **kwargs)
        self.encoder.add_mixin("final", EncoderFinalMixin())

        if decoder is not None:
            assert isinstance(decoder, BaseModel)
            self.decoder = decoder
        else:
            dec_args = argparse.Namespace(**vars(args))
            dec_args.enc_hidden_size = dec_args.hidden_size  # used for cross attn
            override_attrs = ['num_layers', 'hidden_size', 'num_attention_heads', 'layernorm_order',
                              'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head']
            for name in override_attrs:
                dec_attr = getattr(dec_args, 'dec_' + name, None)
                if dec_attr is not None:  # else use encoder-config
                    setattr(dec_args, name, dec_attr)
            self.decoder = BaseModel(dec_args, is_decoder=True, **kwargs)

        self.tie_word_embeddings = tie_word_embeddings
        if tie_word_embeddings:
            self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings

    def reinit(self, mixin_names):  # please use different mixin names for encoder and decoder
        self.encoder.reinit(mixin_names)
        self.decoder.reinit(mixin_names)

    def disable_untrainable_params(self):
        self.encoder.disable_untrainable_params()
        self.decoder.disable_untrainable_params()

    def encode(self, input_ids, position_ids, attention_mask=None, **kw_args):
        encoder_outputs, *_dumps = self.encoder(input_ids, position_ids, attention_mask, **kw_args)
        return encoder_outputs

    def decode(self, input_ids, position_ids, attention_mask, encoder_outputs, cross_attention_mask=None, **kw_args):
        if attention_mask is None:
            # Default to a causal (lower-triangular) self-attention mask of shape
            # [batch_size, 1, seq_length, seq_length], cast to the embedding dtype.
            batch_size, seq_length = input_ids.size()[:2]
            seq_ids = torch.arange(seq_length, device=input_ids.device)
            attention_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
            attention_mask = attention_mask.to(self.decoder.transformer.word_embeddings.weight.dtype)
            attention_mask = attention_mask[:, None, :, :]
        # If no context, please explicitly pass ``encoder_outputs=None``.
        return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)

    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args):
        # Please use self.decoder for auto-regressive generation.
        batch_size, seq_length = enc_input_ids.size()[:2]
        if enc_attention_mask is None:
            enc_attention_mask = torch.ones(1, 1, 1, seq_length, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=enc_input_ids.device)
        if cross_attention_mask is None:
            cross_attention_mask = enc_attention_mask
        encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
        decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
        return (encoder_outputs, decoder_outputs, *mems)

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('EncoderDecoderModel', 'T5 or Bart')
        group.add_argument("--dec-num-layers", type=int, default=None)
        group.add_argument("--dec-hidden-size", type=int, default=None)
        group.add_argument("--dec-num-attention-heads", type=int, default=None)
        group.add_argument("--dec-max-sequence-length", type=int, default=None)
        group.add_argument("--dec-inner-hidden-size", type=int, default=None)
        group.add_argument("--dec-hidden-size-per-attention-head", type=int, default=None)
        group.add_argument("--dec-layernorm-order", type=str, default=None)
        return parser

    @classmethod
    def from_pretrained(cls, args, name, *, home_path=None, url=None):  # TODO update model-only mode
        model_path = auto_create(name, path=home_path, url=url)
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        model = get_model(args, cls)
        load_checkpoint(model, args, load_path=model_path)
        return model, args
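

if __name__ == "__main__":
    # Minimal, hedged usage sketch: it only exercises the pieces of this module
    # that run without a trained checkpoint or distributed setup. The argument
    # values below are illustrative, not defaults from the library.
    parser = argparse.ArgumentParser()
    parser = EncoderDecoderModel.add_model_specific_args(parser)
    demo_args = parser.parse_args(["--dec-num-layers", "6", "--dec-hidden-size", "512"])
    print(demo_args.dec_num_layers, demo_args.dec_hidden_size)  # -> 6 512

    # Illustration of the default causal mask that decode() builds when
    # attention_mask is None: lower-triangular, shaped [batch, 1, seq, seq].
    seq_ids = torch.arange(4)
    causal = (seq_ids[None, None, :].repeat(2, 4, 1) <= seq_ids[None, :, None]).float()[:, None, :, :]
    print(causal.shape)  # torch.Size([2, 1, 4, 4])

    # With a fully configured ``args`` (generic BaseModel arguments such as
    # num_layers, hidden_size, vocab_size, plus sat's model-parallel
    # initialization), the typical call pattern would be:
    #   model = EncoderDecoderModel(args)
    #   enc_out, dec_out, *mems = model(enc_input_ids, enc_position_ids,
    #                                   dec_input_ids, dec_position_ids)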