def forward()

in ttw/models/language.py

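Runs one forward pass of the utterance generator. The encoder first turns the
observation/action history into a context embedding; the decoder then either
scores the gold utterance with teacher forcing (train=True) or generates a new
one via greedy decoding, sampling, or beam search (train=False). The snippet
assumes torch, torch.nn.functional as F, torch.autograd.Variable, and the
repo's SequenceGenerator beam-search helper are in scope.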

    def forward(self, batch, decoding_strategy='beam_search',
                max_sample_length=20, beam_width=4, train=True):
        batch_size = batch['goldstandard'].size(0)
        # derive true sequence lengths from the padding masks
        obs_seq_len = batch['goldstandard_mask'][:, :, 0].sum(1).long()
        if batch['actions_mask'].dim() > 1:
            act_seq_len = batch['actions_mask'].sum(1).long()
        else:
            # no action sequence in this batch: use zero lengths, kept on
            # the same device as the rest of the batch
            act_seq_len = Variable(torch.LongTensor(batch_size).fill_(0))
            if batch['goldstandard'].is_cuda:
                act_seq_len = act_seq_len.cuda()
        # encode observations and actions into a single context embedding
        context_emb = self.encode(batch['goldstandard'], obs_seq_len, batch['actions'], act_seq_len)

        if train:
            # teacher forcing: feed the gold utterance shifted right and
            # predict the next token at every position
            assert 'utterance' in batch and 'utterance_mask' in batch
            inp = batch['utterance'][:, :-1]
            tgt = batch['utterance'][:, 1:]

            inp_emb = self.emb_fn.forward(inp)

            # tile the encoder context across time and concatenate it onto
            # every input word embedding
            context_emb = context_emb.view(batch_size, 1, self.decoder_emb_sz).repeat(1, inp_emb.size(1), 1)
            inp_emb = torch.cat([inp_emb, context_emb], 2)

            hs, _ = self.decoder(inp_emb)

            score = self.out_linear(hs)

            loss = 0.0
            mask = batch['utterance_mask'][:, 1:]

            # accumulate token-level NLL over all time steps, masking out
            # padding; self.loss is assumed to return per-example losses
            for j in range(score.size(1)):
                flat_mask = mask[:, j]
                flat_score = score[:, j, :]
                flat_tgt = tgt[:, j]
                nll = self.loss(flat_score, flat_tgt)
                loss += (flat_mask * nll).sum()

            out = {'loss': loss}
        else:
            if decoding_strategy in ['greedy', 'sample']:
                preds = []
                probs = []

                # start every sequence from the start token with a zeroed
                # decoder state; eos tracks which sequences have finished
                input_ind = torch.LongTensor([self.start_token] * batch_size)
                hs = Variable(torch.FloatTensor(1, batch_size, self.decoder_hid_sz).fill_(0.0))
                mask = Variable(torch.FloatTensor(batch_size, max_sample_length).zero_())
                eos = torch.ByteTensor([0] * batch_size)
                if batch['goldstandard'].is_cuda:
                    hs = hs.cuda()
                    eos = eos.cuda()
                    mask = mask.cuda()
                    input_ind = input_ind.cuda()

                for k in range(max_sample_length):
                    inp_emb = self.emb_fn.forward(input_ind.unsqueeze(-1))

                    # re-attach the (batch, 1, emb) context to the single-step input
                    context_emb = context_emb.view(batch_size, 1, self.decoder_emb_sz).repeat(1, inp_emb.size(1), 1)
                    inp_emb = torch.cat([inp_emb, context_emb], 2)

                    _, hs = self.decoder(inp_emb, hs)

                    prob = F.softmax(self.out_linear(hs.squeeze(0)), dim=-1)
                    if decoding_strategy == 'greedy':
                        _, samples = prob.max(1)
                        samples = samples.unsqueeze(-1)
                    else:
                        samples = prob.multinomial(1)
                    # write the mask before updating eos so the end token
                    # itself stays unmasked
                    mask[:, k] = 1.0 - eos.float()

                    eos = eos | (samples == self.end_token).squeeze(-1)

                    preds.append(samples)
                    probs.append(prob.unsqueeze(1))
                    input_ind = samples.squeeze(-1)

                out = {
                    'utterance': torch.cat(preds, 1),
                    'utterance_mask': mask,
                    'probs': torch.cat(probs, 1),
                }
            elif decoding_strategy == 'beam_search':
                # single decoder step consumed by the beam search; note that
                # this path assumes CUDA is available
                def _step_fn(input, hidden, context, k=4):
                    input = Variable(torch.LongTensor(input)).squeeze().cuda()
                    hidden = Variable(torch.FloatTensor(hidden)).unsqueeze(0).cuda()
                    context = Variable(torch.FloatTensor(context)).unsqueeze(1).cuda()

                    prob, hs = self.step(input, hidden, context)

                    # return the top-k candidate words with their
                    # log-probabilities, plus the updated hidden state
                    logprobs = torch.log(prob)
                    logprobs, words = logprobs.topk(k, 1)
                    hs = hs.squeeze(0).cpu().data.numpy()

                    return words, logprobs, hs

                seq_gen = SequenceGenerator(_step_fn, self.end_token, max_sequence_length=max_sample_length,
                                            beam_size=beam_width, length_normalization_factor=0.5)
                start_tokens = [[self.start_token] for _ in range(batch_size)]
                hidden = [[0.0] * self.decoder_hid_sz] * batch_size
                beam_out = seq_gen.beam_search(start_tokens, hidden, context_emb.cpu().data.numpy())
                pred_tensor = torch.LongTensor(batch_size, max_sample_length).zero_()
                mask_tensor = torch.FloatTensor(batch_size, max_sample_length).zero_()

                # copy each hypothesis (minus the start token) into the
                # fixed-size prediction and mask tensors
                for i, seq in enumerate(beam_out):
                    pred_tensor[i, :(len(seq.output) - 1)] = torch.LongTensor(seq.output[1:])
                    mask_tensor[i, :(len(seq.output) - 1)] = 1.0

                out = {
                    'utterance': Variable(pred_tensor),
                    'utterance_mask': Variable(mask_tensor),
                }

                if batch['goldstandard'].is_cuda:
                    out['utterance'] = out['utterance'].cuda()
                    out['utterance_mask'] = out['utterance_mask'].cuda()

        return out
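
A minimal usage sketch. Everything here is hypothetical: the model instance is
assumed to be constructed elsewhere, and the batch shapes are inferred from the
indexing inside forward(), so the real dataset layout may differ.

    # Hypothetical usage; `model`, B, T_OBS, FEAT, T_ACT, T_UTT are all
    # assumptions inferred from the indexing in forward(), not repo facts.
    import torch
    from torch.autograd import Variable

    B, T_OBS, FEAT, T_ACT, T_UTT = 2, 5, 10, 4, 8
    batch = {
        'goldstandard':      Variable(torch.ones(B, T_OBS, FEAT).long()),
        'goldstandard_mask': Variable(torch.ones(B, T_OBS, FEAT)),
        'actions':           Variable(torch.ones(B, T_ACT).long()),
        'actions_mask':      Variable(torch.ones(B, T_ACT)),
        'utterance':         Variable(torch.ones(B, T_UTT).long()),
        'utterance_mask':    Variable(torch.ones(B, T_UTT)),
    }

    out = model.forward(batch, train=True)               # {'loss': masked NLL}
    out = model.forward(batch, train=False,
                        decoding_strategy='greedy')      # {'utterance', 'utterance_mask', 'probs'}

Note that _step_fn calls .cuda() unconditionally, so the default beam_search
strategy requires a GPU; greedy and sample follow the device of
batch['goldstandard'].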