in ttw/models/language.py [0:0]
def show_samples(self, dataset, num_samples=10, cuda=True, logger=None, decoding_strategy='sample',
                 indices=None, beam_width=4):
    """Decode a few dataset examples and log them next to their reference utterance.

    For every selected example this emits, in order: the decoded
    observations, the decoded actions, the ground-truth utterance, the
    model's generated utterance and an 80-character dashed separator.

    Args:
        dataset: Indexable, sized dataset. Each item is read through its
            ``'goldstandard'`` and ``'actions'`` entries, and the dataset
            supplies the ``map.landmark_dict``, ``act_dict`` and ``dict``
            decoders used for printing.
        num_samples: Number of random indices to draw when ``indices`` is
            ``None``; ignored otherwise.
        cuda: Passed through to ``get_collate_fn``.
        logger: Optional callable accepting a single string. ``print`` is
            used when this is falsy.
        decoding_strategy: Passed through to ``self.forward``.
        indices: Explicit dataset indices to display. When ``None``,
            ``num_samples`` indices are drawn independently with
            ``random.randint`` (so repeats are possible).
        beam_width: Passed through to ``self.forward``.
    """
    if indices is None:
        indices = [random.randint(0, len(dataset) - 1) for _ in range(num_samples)]

    collate = get_collate_fn(cuda)
    examples = [dataset[idx] for idx in indices]
    batch = collate(examples)

    result = self.forward(batch, decoding_strategy=decoding_strategy, train=False, beam_width=beam_width)
    generated = result['utterance'].cpu().data

    emit = logger if logger else print

    for row, example in enumerate(examples):
        # One "(lm1,lm2,...) ," group per observation step, concatenated as-is.
        observations = ''.join(
            '(' + ','.join([dataset.map.landmark_dict.decode(lm) for lm in step]) + ') ,'
            for step in example['goldstandard']
        )
        actions = ','.join([dataset.act_dict.decode(act) for act in example['actions']])

        emit('Observations: ' + observations)
        emit('Actions: ' + actions)
        # The first token of the reference utterance is skipped via the 1: slice.
        emit('GT: ' + dataset.dict.decode(batch['utterance'][row, 1:]))
        emit('Sample: ' + dataset.dict.decode(generated[row, :]))
        emit('-' * 80)