in ttw/models/language.py [0:0]
def forward(self, batch, decoding_strategy='beam_search',
            max_sample_length=20, beam_width=4, train=True):
    """Run the language generator in teacher-forced training mode or in
    free-running decoding mode.

    Args:
        batch: dict of tensors. Must contain 'goldstandard',
            'goldstandard_mask', 'actions', and 'actions_mask'; when
            ``train=True`` it must also contain 'utterance' and
            'utterance_mask'.
        decoding_strategy: one of 'greedy', 'sample', 'beam_search'.
            Only consulted when ``train=False``.
        max_sample_length: maximum number of tokens to decode.
        beam_width: beam size for beam-search decoding.
        train: when True, compute the teacher-forcing NLL loss; otherwise
            decode an utterance with the chosen strategy.

    Returns:
        dict with key 'loss' when training, or keys 'utterance' and
        'utterance_mask' (plus 'probs' for greedy/sample) when decoding.
    """
    batch_size = batch['goldstandard'].size(0)
    # Single source of truth for device placement; every tensor created
    # below follows the device of the input batch.
    use_cuda = batch['goldstandard'].is_cuda
    obs_seq_len = batch['goldstandard_mask'][:, :, 0].sum(1).long()
    if batch['actions_mask'].dim() > 1:
        act_seq_len = batch['actions_mask'].sum(1).long()
    else:
        # No action sequences in this batch: use zero lengths.
        act_seq_len = Variable(torch.LongTensor(batch_size).fill_(0))
        if use_cuda:
            # BUGFIX: .cuda() was unconditional here, crashing CPU-only
            # runs; the decoding branch below already gates on is_cuda.
            act_seq_len = act_seq_len.cuda()
    context_emb = self.encode(batch['goldstandard'], obs_seq_len,
                              batch['actions'], act_seq_len)
    if train:
        # Teacher forcing: feed gold tokens shifted right by one and score
        # next-token predictions against the left-shifted targets.
        assert 'utterance_mask' in batch and 'utterance' in batch
        inp = batch['utterance'][:, :-1]
        tgt = batch['utterance'][:, 1:]
        inp_emb = self.emb_fn.forward(inp)
        # Tile the encoder context across the time dimension and
        # concatenate it onto every input embedding.
        context_emb = context_emb.view(
            batch_size, 1, self.decoder_emb_sz).repeat(1, inp_emb.size(1), 1)
        inp_emb = torch.cat([inp_emb, context_emb], 2)
        hs, _ = self.decoder(inp_emb)
        score = self.out_linear(hs)
        loss = 0.0
        mask = batch['utterance_mask'][:, 1:]
        # Sum the masked per-timestep NLL over the whole batch (padding
        # positions contribute nothing because their mask is 0).
        for j in range(score.size(1)):
            flat_mask = mask[:, j]
            flat_score = score[:, j, :]
            flat_tgt = tgt[:, j]
            nll = self.loss(flat_score, flat_tgt)
            loss += (flat_mask * nll).sum()
        out = {}
        out['loss'] = loss
    else:
        if decoding_strategy in ['greedy', 'sample']:
            preds = []
            probs = []
            input_ind = torch.LongTensor([self.start_token] * batch_size)
            hs = Variable(torch.FloatTensor(
                1, batch_size, self.decoder_hid_sz).fill_(0.0))
            mask = Variable(torch.FloatTensor(
                batch_size, max_sample_length).zero_())
            eos = torch.ByteTensor([0] * batch_size)
            if use_cuda:
                hs = hs.cuda()
                eos = eos.cuda()
                mask = mask.cuda()
                input_ind = input_ind.cuda()
            # Hoisted loop invariant: we decode one token at a time, so the
            # per-step context always has time dimension 1 — the original
            # re-viewed (and no-op repeated) this tensor every iteration.
            context_step = context_emb.view(batch_size, 1, self.decoder_emb_sz)
            for k in range(max_sample_length):
                inp_emb = self.emb_fn.forward(input_ind.unsqueeze(-1))
                inp_emb = torch.cat([inp_emb, context_step], 2)
                _, hs = self.decoder(inp_emb, hs)
                prob = F.softmax(self.out_linear(hs.squeeze(0)), dim=-1)
                if decoding_strategy == 'greedy':
                    _, samples = prob.max(1)
                    samples = samples.unsqueeze(-1)
                else:
                    samples = prob.multinomial(1)
                # Positions that already emitted EOS stop contributing.
                mask[:, k] = 1.0 - eos.float()
                # BUGFIX: squeeze only dim 1 — the original bare .squeeze()
                # also collapsed the batch dimension when batch_size == 1,
                # turning eos into a scalar.
                eos = eos | (samples == self.end_token).squeeze(1)
                preds.append(samples)
                probs.append(prob.unsqueeze(1))
                input_ind = samples.squeeze(-1)
            out = {}
            out['utterance'] = torch.cat(preds, 1)
            out['utterance_mask'] = mask
            out['probs'] = torch.cat(probs, 1)
        elif decoding_strategy == 'beam_search':
            def _step_fn(input, hidden, context, k=4):
                # One decoder step on numpy-encoded beam state, returning
                # the top-k continuations for the beam-search driver.
                # NOTE(review): the default k=4 matches the default
                # beam_width; presumably SequenceGenerator passes its own
                # k per call — verify against its implementation.
                input = Variable(torch.LongTensor(input)).squeeze()
                hidden = Variable(torch.FloatTensor(hidden)).unsqueeze(0)
                context = Variable(torch.FloatTensor(context)).unsqueeze(1)
                if use_cuda:
                    # BUGFIX: .cuda() was unconditional, breaking CPU-only
                    # beam-search decoding (greedy/sample already gated on
                    # is_cuda).
                    input = input.cuda()
                    hidden = hidden.cuda()
                    context = context.cuda()
                prob, hs = self.step(input, hidden, context)
                logprobs = torch.log(prob)
                logprobs, words = logprobs.topk(k, 1)
                hs = hs.squeeze().cpu().data.numpy()
                return words, logprobs, hs
            seq_gen = SequenceGenerator(
                _step_fn, self.end_token,
                max_sequence_length=max_sample_length,
                beam_size=beam_width,
                length_normalization_factor=0.5)
            start_tokens = [[self.start_token] for _ in range(batch_size)]
            hidden = [[0.0] * self.decoder_hid_sz] * batch_size
            beam_out = seq_gen.beam_search(
                start_tokens, hidden, context_emb.cpu().data.numpy())
            pred_tensor = torch.LongTensor(batch_size, max_sample_length).zero_()
            mask_tensor = torch.FloatTensor(batch_size, max_sample_length).zero_()
            for i, seq in enumerate(beam_out):
                # Drop the leading start token and mark the emitted
                # positions in the mask.
                pred_tensor[i, :(len(seq.output) - 1)] = torch.LongTensor(seq.output[1:])
                mask_tensor[i, :(len(seq.output) - 1)] = 1.0
            out = {}
            out['utterance'] = Variable(pred_tensor)
            out['utterance_mask'] = Variable(mask_tensor)
            if use_cuda:
                out['utterance'] = out['utterance'].cuda()
                out['utterance_mask'] = out['utterance_mask'].cuda()
    return out