submission_code/tools.py (133 lines of code) (raw):
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from transformers import BertConfig, BertModel, EncoderDecoderConfig, EncoderDecoderModel
class MtDataset(data.Dataset):
def __init__(self, src, tgt, config, bos_id, eos_id, pad_id):
self.src = nn.utils.rnn.pad_sequence([
torch.tensor([bos_id] + x[:config.max_src_len] + [eos_id]).long() for x in src],
batch_first=True, padding_value=pad_id)
self.tgt = nn.utils.rnn.pad_sequence([
torch.tensor([bos_id] + x[:config.max_tgt_len] + [eos_id]).long() for x in tgt],
batch_first=True, padding_value=pad_id)
def __len__(self):
return len(self.src)
def __getitem__(self, i):
return {
'features': (self.src[i], self.tgt[i][:-1]),
'targets': self.tgt[i][1:]}
class Transformer(nn.Module):
def __init__(self, config, pad_id):
super(Transformer, self).__init__()
encoder_config = BertConfig(
vocab_size=config.src_vocab_size,
hidden_size=config.h_size,
num_hidden_layers=config.enc_layers,
num_attention_heads=config.n_heads,
intermediate_size=config.d_ff,
hidden_dropout_prob = config.dropout,
pad_token_id=pad_id,
)
decoder_config = BertConfig(
vocab_size=config.tgt_vocab_size,
hidden_size=config.h_size,
num_hidden_layers=config.dec_layers,
num_attention_heads=config.n_heads,
intermediate_size=config.d_ff,
hidden_dropout_prob = config.dropout,
pad_token_id=pad_id,
is_decoder=True,
add_cross_attention=True,
)
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
self.tr = EncoderDecoderModel(config=encoder_decoder_config)
if config.joined_vocab:
self.tr.encoder.embeddings.word_embeddings = self.tr.decoder.bert.embeddings.word_embeddings
def forward(self, x):
src, tgt = x
src_attn = (src != 0).float()
tgt_attn = (tgt != 0).float()
x = self.tr(
input_ids=src,
attention_mask=src_attn,
decoder_input_ids=tgt,
decoder_attention_mask=tgt_attn,
)
x = x[0].permute(0,2,1)
return x
def beam_search(src, model, pad_token, bos_id, end_token, max_len=10, k=5):
device = next(model.parameters()).device
src = src.view(1,-1).to(device)
src_mask = (src != pad_token).to(device)
memory = None
input_seq = [bos_id]
beam = [(input_seq, 0)]
for i in range(max_len):
candidates = []
candidates_proba = []
for snt, snt_proba in beam:
if snt[-1] == end_token:
candidates.append(snt)
candidates_proba.append(snt_proba)
else:
snt_tensor = torch.tensor(snt).view(1, -1).long().to(device)
if memory is None:
memory = model(
input_ids=src,
attention_mask=src_mask,
decoder_input_ids=snt_tensor,
)
else:
memory = model(
input_ids=src,
attention_mask=src_mask,
decoder_input_ids=snt_tensor,
encoder_outputs=(memory[1], memory[-1]),
)
proba = memory[0].cpu()[0,-1, :]
proba = torch.log_softmax(proba, dim=-1).numpy()
best_k = np.argpartition(-proba, k - 1)[:k]
for tok in best_k:
candidates.append(snt + [tok])
candidates_proba.append(snt_proba + proba[tok])
best_candidates = np.argpartition(-np.array(candidates_proba), k - 1)[:k]
beam = [(candidates[j], candidates_proba[j]) for j in best_candidates]
beam = sorted(beam, key=lambda x: -x[1])
return beam
class BertClassifier(nn.Module):
def __init__(self, config, pad_id, num_classes):
super(BertClassifier, self).__init__()
bert_config = BertConfig(
vocab_size=config.src_vocab_size,
hidden_size=config.h_size,
num_hidden_layers=config.n_layers,
num_attention_heads=config.n_heads,
intermediate_size=config.d_ff,
hidden_dropout_prob = config.dropout,
pad_token_id=pad_id,
)
self.tr = BertModel(config=bert_config)
self.drop = nn.Dropout(config.dropout)
self.out = nn.Linear(config.h_size, num_classes)
def forward(self, x):
attn = (x != 0).float()
x = self.tr(
input_ids=x,
attention_mask=attn,
return_dict=True
)
x = x.last_hidden_state.mean(dim=1)
x = self.drop(x)
x = self.out(x)
return x
class UtilDataset(data.Dataset):
def __init__(self, src, tgt, config, bos_id, eos_id, pad_id):
self.src = nn.utils.rnn.pad_sequence([
torch.tensor([bos_id] + x[:config.max_src_len] + [eos_id]).long() for x in src],
batch_first=True, padding_value=pad_id)
self.tgt = torch.tensor(tgt.values).long()
def __len__(self):
return len(self.src)
def __getitem__(self, i):
return {'features': self.src[i], 'targets': self.tgt[i]}