def show_samples()

in ttw/models/language.py [0:0]


    def show_samples(self, dataset, num_samples=10, cuda=True, logger=None, decoding_strategy='sample',
                     indices=None, beam_width=4):
        """Decode a few dataset examples and log them next to their ground truth.

        For every selected example this emits, via ``logger`` (or ``print``):
        the observations, the actions, the ground-truth utterance, the utterance
        produced by ``self.forward(..., train=False)``, and a separator line.

        Args:
            dataset: Indexable dataset supporting ``len()``; must expose
                ``map.landmark_dict``, ``act_dict`` and ``dict`` objects that each
                provide a ``decode`` method.
            num_samples: Number of random examples to draw when ``indices`` is
                None. Ignored when ``indices`` is given.
            cuda: Forwarded to ``get_collate_fn`` to build the batch collator.
            logger: Optional callable taking one string; ``print`` is used when
                it is falsy.
            decoding_strategy: Forwarded to ``self.forward``.
            indices: Explicit dataset indices to show. When None, ``num_samples``
                indices are drawn uniformly at random, with replacement.
            beam_width: Forwarded to ``self.forward``.

        Raises:
            ValueError: If random sampling is requested from an empty dataset.
        """
        if indices is None:
            if num_samples > 0 and len(dataset) == 0:
                # random.randint(0, -1) would otherwise fail with an opaque "empty range" error.
                raise ValueError('show_samples: cannot draw samples from an empty dataset')
            # Drawn with replacement, so the same example may be shown more than once.
            indices = [random.randint(0, len(dataset) - 1) for _ in range(num_samples)]

        collate_fn = get_collate_fn(cuda)
        data = [dataset[ind] for ind in indices]
        batch = collate_fn(data)

        # NOTE(review): gradients are not needed for sampling; consider wrapping this
        # call in torch.no_grad() if the PyTorch version in use supports it.
        out = self.forward(batch, decoding_strategy=decoding_strategy, train=False, beam_width=beam_width)
        generated_utterance = out['utterance'].cpu().data

        logger_fn = logger or print
        landmark_decode = dataset.map.landmark_dict.decode
        action_decode = dataset.act_dict.decode
        utterance_decode = dataset.dict.decode

        for i, example in enumerate(data):
            # Each observation is rendered as "(l1,l2,...) ," and the pieces are concatenated.
            observations = ''.join(
                '(' + ','.join(landmark_decode(o_ind) for o_ind in obs) + ') ,'
                for obs in example['goldstandard']
            )
            actions = ','.join(action_decode(a_ind) for a_ind in example['actions'])

            logger_fn('Observations: ' + observations)
            logger_fn('Actions: ' + actions)
            # Position 0 of the ground-truth utterance is skipped (presumably a
            # start-of-sequence token — TODO confirm against the collate function).
            logger_fn('GT: ' + utterance_decode(batch['utterance'][i, 1:]))
            logger_fn('Sample: ' + utterance_decode(generated_utterance[i, :]))
            logger_fn('-' * 80)