in code/src/model/transformer.py [0:0]
def __init__(self, args, encoder):
    """Build the decoder from the options in `args` and a paired encoder.

    Sets up, in order: copied hyper-parameters, special-token indexes,
    optional attribute embeddings (start-of-sentence and/or output bias),
    token and positional embeddings, the stack of decoder layers, and the
    output projection onto the vocabulary.

    Args:
        args: Options object. Attributes read here: attributes,
            attr_values, n_words, decoder_embed_dim, dropout,
            share_encdec_emb, share_decpro_emb, freeze_dec_emb, beam_size,
            length_penalty, bos_attr, bias_attr, bos_index, eos_index,
            pad_index, left_pad_target, decoder_layers, and -- only when
            bos_attr or bias_attr is set -- attr_offset, attr_shifts and
            n_labels.
        encoder: Encoder instance. Its class is recorded, and its
            `embeddings` module is reused when `args.share_encdec_emb`
            is true.

    Raises:
        AssertionError: If `args.bos_attr` or `args.bias_attr` is not one
            of '', 'avg' or 'cross'.
    """
    super().__init__()
    self.attributes = args.attributes
    self.attr_values = args.attr_values
    self.n_words = args.n_words
    self.emb_dim = args.decoder_embed_dim
    self.dropout = args.dropout
    self.share_encdec_emb = args.share_encdec_emb
    self.share_decpro_emb = args.share_decpro_emb
    self.freeze_dec_emb = args.freeze_dec_emb
    self.encoder_class = encoder.__class__
    self.beam_size = args.beam_size
    self.length_penalty = args.length_penalty
    self.bos_attr = args.bos_attr
    self.bias_attr = args.bias_attr
    assert self.bos_attr in ['', 'avg', 'cross']
    assert self.bias_attr in ['', 'avg', 'cross']

    # indexes
    self.bos_index = args.bos_index
    self.eos_index = args.eos_index
    self.pad_index = args.pad_index

    # attribute embeddings / bias
    if self.bos_attr != '' or self.bias_attr != '':
        # Keep private copies as buffers so they follow the module
        # (device moves, state_dict) without being trained.
        self.register_buffer('attr_offset', args.attr_offset.clone())
        self.register_buffer('attr_shifts', args.attr_shifts.clone())
    if self.bos_attr != '':
        # 'avg': one row per label (sum of label counts);
        # 'cross': one row per label combination (product of label counts).
        n_bos_attr = sum(args.n_labels) if self.bos_attr == 'avg' else reduce(mul, args.n_labels, 1)
        self.bos_attr_embeddings = nn.Embedding(n_bos_attr, self.emb_dim)
    if self.bias_attr != '':
        n_bias_attr = sum(args.n_labels) if self.bias_attr == 'avg' else reduce(mul, args.n_labels, 1)
        # Each row has n_words entries (same width as the output projection).
        self.bias_attr_embeddings = nn.Embedding(n_bias_attr, self.n_words)

    # embedding layers
    if self.share_encdec_emb:
        logger.info("Sharing encoder and decoder input embeddings")
        self.embeddings = encoder.embeddings
    else:
        self.embeddings = Embedding(self.n_words, self.emb_dim, padding_idx=self.pad_index)
    self.embed_scale = math.sqrt(self.emb_dim)
    # NOTE(review): 1024 is presumably the number of supported positions --
    # confirm against PositionalEmbedding.
    self.embed_positions = PositionalEmbedding(
        1024, self.emb_dim, self.pad_index,
        left_pad=args.left_pad_target,
    )

    # decoder layers
    self.layers = nn.ModuleList()
    for _ in range(args.decoder_layers):
        # append() grows the list; index assignment on an empty ModuleList
        # is out of range and raises IndexError.
        self.layers.append(TransformerDecoderLayer(args))

    # projection layers
    proj = nn.Linear(self.emb_dim, self.n_words)
    if self.share_decpro_emb:
        logger.info("Sharing input embeddings and projection matrix in the decoder")
        # Tie weights: the projection reuses the input embedding matrix.
        proj.weight = self.embeddings.weight
    self.proj = proj