in models.py [0:0]
def visualize(self, sent, tokenize=True):
    """Encode one sentence and plot how often each token wins the max-pooling.

    The sentence is split into words, either with ``self.tokenize`` or with
    ``str.split``. Words that are not keys of ``self.word_vec`` are dropped, and
    the rest are wrapped in ``self.bos`` / ``self.eos``. The resulting single
    sentence is batched with ``self.get_batch`` and encoded with
    ``self.enc_lstm``. The encoder output is then max-pooled with
    ``torch.max(..., 0)``.

    For every token, the percentage of pooled positions whose argmax index
    equals that token's position is drawn as a matplotlib bar chart and shown
    with ``plt.show()``.

    Args:
        sent: the input sentence as a single string.
        tokenize: if True, split ``sent`` with ``self.tokenize``; otherwise
            split it on whitespace with ``str.split``.

    Returns:
        ``(output, idxs)``: the max-pooled values returned by ``torch.max``
        and the matching argmax indices converted to a NumPy array.
    """
    raw_sent = sent  # keep the caller's input for the warning message
    words = sent.split() if not tokenize else self.tokenize(sent)
    kept_words = [word for word in words if word in self.word_vec]
    tokens = [self.bos] + kept_words + [self.eos]

    if not kept_words:
        import warnings
        # Report the sentence the caller passed in, not the filtered token list.
        warnings.warn('No words in "%s" have w2v vectors. Replacing by "%s %s"..'
                      % (raw_sent, self.bos, self.eos), stacklevel=2)

    batch = self.get_batch([tokens])
    if self.is_cuda():
        batch = batch.cuda()
    output = self.enc_lstm(batch)[0]
    # Max over dim 0; the argmax indices are taken to be token positions.
    # NOTE(review): assumes enc_lstm output is (seq_len, batch, dim) — confirm.
    output, idxs = torch.max(output, 0)
    idxs = idxs.data.cpu().numpy()

    n_tokens = len(tokens)
    # Count how many pooled positions picked each token; positions >= n_tokens
    # (if any) are ignored, matching the original range(len(tokens)) loop.
    counts = np.bincount(idxs.ravel(), minlength=n_tokens)[:n_tokens]
    total = np.sum(counts)
    percentages = [100.0 * n / total for n in counts]

    # visualize model
    import matplotlib.pyplot as plt
    positions = range(n_tokens)
    plt.xticks(positions, tokens, rotation=45)
    plt.bar(positions, percentages)
    plt.ylabel('%')
    plt.title('Visualisation of words importance')
    plt.show()

    return output, idxs