in models_vd/executor.py [0:0]
def segregrate_outputs(self, output):
    '''
    Split the executor's flat attention / Refer-weight outputs per module.

    Walks the predicted caption program (only when the model name contains
    'nmn-cap') and then the predicted question program of every round.
    For each module whose output type is 'att' it consumes the next entry of
    output['att']; for each '_Refer' module it consumes the next slice of
    output['logits'], one weight per previously seen '_Find' (and per answer
    when params['use_fact'] is set).

    Assumes a batch size of 1.

    Args:
        output: dict with keys 'pred_tokens', 'att', 'logits', and also
            'pred_tokens_cap' when the model name contains 'nmn-cap'.

    Returns:
        (sep_att, sep_wts):
          sep_att: list of (source, t_id, r_id, attention) tuples, where
              source is 'cap' or 'ques' (caption entries use r_id 0).
          sep_wts: list of (r_id, weights, labels) tuples, one per '_Refer';
              labels names each entry of weights: 'C_<t>' for a caption
              '_Find', 'Q<r>_<t>' for a question '_Find', 'A<r>' for an
              answer (only when params['use_fact'] is set).

    Raises:
        AssertionError: for non-baseline models, if a Refer weight slice does
            not sum to 1, or if not every weight / attention was consumed.
    '''
    use_caption = 'nmn-cap' in self.params['model']
    if use_caption:
        # presumably column 0 selects the single batch element -- confirm
        cap_tokens = output['pred_tokens_cap'][:, 0]
    ques_tokens = output['pred_tokens']
    mod_out_type = _module_output_type
    mod_dict = self._assembler.module_names
    att = output['att']
    # logits -> weights when visualizing; consumed in slices, one per _Refer
    weights = output['logits']

    # segregated outputs
    sep_att = []
    sep_wts = []
    # Labels of every weight a later _Refer ranges over, in emission order.
    wt_labels = []
    num_reuse = 0  # always equal to len(wt_labels)
    att_ind = 0
    weight_ind = 0

    # go over caption
    if use_caption:
        for t_id in range(self.params['max_dec_len']):
            cur_module = mod_dict[cap_tokens[t_id]]
            if cur_module == '<eos>':
                break
            if mod_out_type[cur_module] == 'att':
                sep_att.append(('cap', t_id, 0, att[att_ind]))
                att_ind += 1
            if cur_module == '_Find':
                wt_labels.append('C_%d' % t_id)
                num_reuse += 1

    # assume a batch size of 1; ques_tokens is indexed [t_id, r_id]
    for r_id in range(self.params['num_rounds']):
        for t_id in range(self.params['max_dec_len']):
            cur_module = mod_dict[ques_tokens[t_id, r_id]]
            if cur_module == '<eos>':
                # even answer has a weight now
                if self.params['use_fact']:
                    wt_labels.append('A%d' % r_id)
                    num_reuse += 1
                break
            if mod_out_type[cur_module] == 'att':
                sep_att.append(('ques', t_id, r_id, att[att_ind]))
                att_ind += 1
            if cur_module == '_Refer':
                end = weight_ind + num_reuse
                # Snapshot the labels: wt_labels keeps growing afterwards, so
                # storing the list itself would leave every entry holding the
                # final label list, longer than its own weight slice.
                sep_wts.append(
                    (r_id, weights[weight_ind:end], list(wt_labels)))
                weight_ind = end
            if cur_module == '_Find':
                wt_labels.append('Q%d_%d' % (r_id, t_id))
                num_reuse += 1

    # do not assert if baseline
    if 'baseline' in self.params['model']:
        return sep_att, sep_wts

    # Each Refer weight slice should be a distribution.
    for arg in sep_wts:
        assert(abs(np.sum(arg[1]) - 1.0) < 1e-5)
    # Sanity checks to ensure Refer is not doing anything weird.
    assert(weight_ind == weights.shape[0])
    assert(att_ind == att.shape[0])
    return sep_att, sep_wts