in vis/visualize_dialogs.py [0:0]
def _save_attention_pair(page, image, attention, row_ids, example_id, turn_id,
                         step, local_save, server_save):
    """Render one attention map two ways and save both as 200x200 PNGs.

    The first rendering uses support.get_blend_map, the second
    support.interpolate_attention (same order as the original inline code).

    Args:
        page: html.HTML page object used to create image links.
        image: RGB image (H x W x 3) the attention is rendered over.
        attention: attention tensor; channel 0 of the last axis is used.
        row_ids: pair of integer row ids for the path templates,
            (blend_row, interpolation_row).
        example_id: index of the current example (ii).
        turn_id: dialog-turn index (jj); 0 for captions.
        step: program step the attention belongs to.
        local_save: '%d_%d_%d_%d.png'-style template for the on-disk path.
        server_save: matching template for the server-relative link path.

    Returns:
        [blend_link, interpolation_link]: HTML snippets linking the two maps.
    """
    att = attention[:, :, 0]
    links = []
    renderers = (support.get_blend_map, support.interpolate_attention)
    for row_id, render in zip(row_ids, renderers):
        slot = (row_id, example_id, turn_id, step)
        att_image = Image.fromarray(np.uint8(render(image, att)))
        att_image = att_image.resize((200, 200))
        att_image.save(local_save % slot, 'png')
        links.append(page.link_image(server_save % slot))
    return links


def _program_length(program, end_token):
    """Length of a predicted program = index of the first end_token.

    Falls back to the full program length when end_token never occurs
    (the original [0][0] indexing raised IndexError in that case).
    """
    hits = np.where(program == end_token)[0]
    return hits[0] if len(hits) else len(program)


def main(args):
    """Visualize saved dialog batches + model outputs as an HTML page.

    Loads the pickled batch/outputs from args.batch_path, renders
    per-caption and per-question module attention maps to PNGs under
    args.image_save_root/attention, and writes a multi-column HTML
    summary page to args.save_path.
    """
    titles = ['Image', 'Answers', 'Predictions', 'Modules', 'Attention']
    # Load the batch and the model outputs saved alongside it.
    data = np.load(args.batch_path)[()]
    batch, outputs = data['batch'], data['output']
    # Load text dictionary (word <-> index).
    with open(args.text_vocab_path, 'r') as file_id:
        word2ind = {word.strip('\n'): ind
                    for ind, word in enumerate(file_id.readlines())}
    ind2word = {ind: word for word, ind in word2ind.items()}
    # Get the program dictionary (module token <-> index).
    with open(args.prog_vocab_path, 'r') as file_id:
        word2ind_prog = {word.strip('\n'): ind
                         for ind, word in enumerate(file_id.readlines())}
    ind2word_prog = {ind: word for word, ind in word2ind_prog.items()}

    def stringify(vector):
        """Space-join the words for a vector of text-token indices."""
        return ' '.join(ind2word[w] for w in vector)

    def stringify_prog(vector):
        """Space-join the modules for a vector of program-token indices."""
        return ' '.join(ind2word_prog[w] for w in vector)

    # Get html related info.
    page = html.HTML(len(titles))
    page.set_title(titles)
    template = 'Q%d: %s\nA [GT]: %s\nP [GT]: %s\nP: %s'
    pred_template = 'GT Rank: %d\n_top-5: \n%s'
    # Saving intermediate outputs: path templates are
    # (row_id, example_id, turn_id, step).
    end_prog_token = word2ind_prog['<eos>']
    server_save = './attention/%d_%d_%d_%d.png'
    local_save = os.path.join(args.image_save_root, 'attention/%d_%d_%d_%d.png')
    # Create folders.
    os.makedirs(args.image_save_root, exist_ok=True)
    os.makedirs(os.path.join(args.image_save_root, 'attention'), exist_ok=True)
    for ii in progressbar(range(args.num_examples)):
        # Read image.
        img_name = '/'.join(batch[ii]['img_path'][0].split('/')[-2:])
        image = io.imread(os.path.join(args.image_load_root, img_name))
        # Deal with black and white images by tiling to three channels.
        if len(image.shape) < 3:
            image = np.expand_dims(image, -1)
            image = np.tile(image, [1, 1, 3])
        # Caption. (The original branched on cap_len.ndim == 2, but both
        # branches were identical; collapsed into a single path.)
        cap_len = batch[ii]['cap_len'][0]
        cap_string = stringify(batch[ii]['cap'][0, :cap_len])
        span_content = page.link_image('coco_images/' + img_name, cap_string, 400)
        if 'pred_tokens_cap' in outputs[ii]:
            # Decide length based on first appearance of <eos>.
            caption_prog = outputs[ii]['pred_tokens_cap']
            prog_len = _program_length(caption_prog[:, 0], end_prog_token)
            cap_tokens = [ind2word[w] for w in batch[ii]['cap'][0, :cap_len]]
            prog_tokens = [ind2word_prog[w] for w in caption_prog[:prog_len, 0]]
            att = 100 * outputs[ii]['attention_cap'][:, :, 0, 0].transpose()
            word_att_str = page.add_question_attention(cap_tokens, prog_tokens, att)
            # Caption module outputs: one blended + one interpolated map
            # per program step (HTML rows 2 and 3).
            stack = outputs[ii]['intermediates'][0]
            cap_stack = [datum for datum in stack if datum[0] == 'cap']
            string = {'c_1': '', 'c_2': ''}
            for _, step, _, attention in cap_stack:
                links = _save_attention_pair(page, image, attention, (2, 3),
                                             ii, 0, step, local_save, server_save)
                string['c_1'] += links[0]
                string['c_2'] += links[1]
            # Add the neural module visualization for captions.
            span_content += '\n'.join(['', string['c_1'], string['c_2'],
                                       word_att_str])
        ques_content = []
        for jj in range(10):
            row_content = []
            # Question.
            ques_len = batch[ii]['ques_len'][jj]
            ques_string = stringify(batch[ii]['ques'][:ques_len, jj])
            # Answer. (The original also built ans_out here but never used
            # it; dropped.)
            ans_len = batch[ii]['ans_len'][jj]
            ans_in = stringify(batch[ii]['ans_in'][jj, :ans_len])
            # Program: ground truth layout vs prediction.
            gt_prog_str = stringify_prog(batch[ii]['gt_layout'][:, jj])
            cur_prog = outputs[ii]['pred_tokens'][:, jj]
            prog_pred = stringify_prog(cur_prog)
            print_slot = (jj, ques_string, ans_in, gt_prog_str, prog_pred)
            row_content.append(template % print_slot)
            # Get predictions: rank of the ground-truth option by score.
            sort_arg = np.argsort(outputs[ii]['scores'][jj])[::-1][:args.top_options]
            gt_score = outputs[ii]['scores'][jj][batch[ii]['gt_ind'][jj]]
            gt_rank = np.sum(outputs[ii]['scores'][jj] > gt_score) + 1
            options = [stringify(batch[ii]['opt_in'][kk][jj]) for kk in sort_arg]
            row_content.append(pred_template % (gt_rank, '\n'.join(options)))
            # Visualizing intermediate outputs for each question
            # (HTML rows 0 and 1).
            stack = outputs[ii]['intermediates'][0]
            ques_stack = [datum for datum in stack
                          if (datum[0] == 'ques') and (datum[2] == jj)]
            string = {'q_1': '', 'q_2': ''}
            for _, step, _, attention in ques_stack:
                links = _save_attention_pair(page, image, attention, (0, 1),
                                             ii, jj, step, local_save, server_save)
                string['q_1'] += links[0]
                string['q_2'] += links[1]
                # If refer module, add weights over history.
                if ind2word_prog[cur_prog[step]] == '_Refer':
                    wt_stack = outputs[ii]['intermediates'][1]
                    cur_wt = [datum for datum in wt_stack if datum[0] == jj]
                    # Message fixed: the check asserts a unique weight
                    # record per question, not normalization.
                    assert len(cur_wt) == 1, \
                        'Expected exactly one history-weight record per question'
                    wts = cur_wt[0][1]
                    wt_labels = cur_wt[0][2]
                    if len(wts) > 0:
                        # BUGFIX: the original overwrote the accumulated
                        # image links with the history attention and then
                        # concatenated it with itself; append it instead.
                        wt_string = page.add_history_attention(wts, wt_labels)
                        string['q_1'] += '\n' + wt_string
            row_content.append('\n'.join(['', string['q_1'], string['q_2']]))
            # Decide length based on first appearance of <eos>.
            prog_len = _program_length(cur_prog, end_prog_token)
            ques_tokens = [ind2word[w] for w in batch[ii]['ques'][:ques_len, jj]]
            prog_tokens = [ind2word_prog[w] for w in cur_prog[:prog_len]]
            att = 100 * outputs[ii]['attention'][:, :, jj, 0].transpose()
            row_content.append(page.add_question_attention(ques_tokens,
                                                           prog_tokens, att))
            ques_content.append(row_content)
        # Add the span row for this example.
        page.add_spanning_row(span_content, ques_content)
    # Render page and save.
    page.save_page(args.save_path)