in models_vd/assembler.py [0:0]
def _assemble_program(self, image, text, fact, text_feat, tokens, reuse_stack):
    """Assemble one module program per round from predicted layout tokens.

    Each column ``tokens[:, r_id]`` is read as a postfix program: every
    token selects a module by index into ``self.module_names``; a module
    pops its inputs from a decode stack and pushes its own output.
    Assembly of a round stops at the first ``<eos>`` token
    (``self.EOS_idx``).

    A round is marked invalid when the layout has no ``<eos>``, a module
    needs more inputs than the stack holds, a module would consume a
    non-attention (``'att'``) value, or the stack does not end with exactly
    one ``'ans'``-typed value.

    Args:
        image: Passed through to the weaver ops (find / transform /
            describe / invalid).
        text: Sliced per (round, time) with ``weaver.slice_text``.
        fact: Optional; when not ``None`` it is sliced per round and the
            slice is recorded on ``reuse_stack`` after every valid round.
        text_feat: Sliced per (round, time) with
            ``weaver.slice_text_feat``.
        tokens: 2-D array of module indices, indexed ``[time, round]``.
        reuse_stack: List of ``(output, text/fact slice, round_id, r_id,
            t_id)`` tuples. It is extended IN PLACE and also returned.

    Returns:
        A 3-tuple ``(outputs, reuse_stack, validity)`` where ``outputs`` is
        ``{'comp': [...], 'vis': {...}}``. ``'comp'`` holds one entry per
        round (preceded by any refer logits collected in train mode), and
        ``validity[r]`` is ``True`` when round ``r`` was INVALID.
    """
    # aliases
    weaver = self.weaver
    executor = self.executor

    # get extent of layout tokens; the values are unused below, but the
    # unpacking fails early unless `tokens` is 2-D
    max_time, batch_size = tokens.shape
    num_rounds = executor.params['num_rounds']

    outputs = []
    validity = []
    # for visualizing internal nodes
    vis_outputs = {'att': [], 'weights': [], 'logits': []}

    for r_id in range(num_rounds):
        layout = tokens[:, r_id]
        invalid_prog = False
        round_id = weaver.batch_input(executor._loom_types['round'], r_id)
        if fact is not None:
            fact_slice = weaver.slice_fact(fact, round_id)

        # valid layout must contain <eos>. Assembly fails if it doesn't.
        # (The steps are still assembled; only the final result is dropped.)
        if not np.any(layout == self.EOS_idx):
            invalid_prog = True

        decode_stack = []
        penult_out = None  # penultimate output
        for t_id, cur_op_id in enumerate(layout):
            weights = None
            time_id = weaver.batch_input(executor._loom_types['time'], t_id)
            text_att = weaver.slice_text(text, round_id, time_id)
            # slice the text feature
            text_feat_slice = weaver.slice_text_feat(text_feat, round_id,
                                                     time_id)
            cur_op_name = self.module_names[cur_op_id]

            # <eos> would mean stop
            if cur_op_id == self.EOS_idx:
                break

            # insufficient number of inputs
            num_inputs = _module_input_num[cur_op_name]
            if len(decode_stack) < num_inputs:
                invalid_prog = True
                break

            # read the inputs
            inputs = []
            bad_input = False
            for _ in range(num_inputs):
                arg, arg_type = decode_stack.pop()
                # cannot consume anything but attention
                if arg_type != 'att':
                    bad_input = True
                    break
                inputs.append(arg)
            if bad_input:
                # Abort the whole program, exactly like the "insufficient
                # inputs" case above. Leaving only the inner loop would let
                # the module below index into a short `inputs` list and
                # raise IndexError instead of reporting an invalid program.
                invalid_prog = True
                break

            # switch cases (module names are mutually exclusive)
            if cur_op_name == '_Find':
                out = weaver.find(image, text_att)
                # collect in reuse stack (always)
                reuse_stack.append(
                    (out, text_feat_slice, round_id, r_id, t_id))
            elif cur_op_name == '_Refer':
                if not reuse_stack:
                    print('Something wrong with Refer')
                    # nothing to refer to: skip this step without pushing
                    # anything on the decode stack
                    continue
                # if baseline is in the model, take the last output
                if 'baseline' in executor.params['model']:
                    out = reuse_stack[-1][0]
                else:
                    out, weights, logits = self.assemble_refer(
                        text_feat_slice, round_id, reuse_stack)
            elif cur_op_name == '_Exclude':
                # clean up reuse stack to avoid current finds
                # NOTE(review): index -2 of a reuse-stack entry is r_id (see
                # the tuple appended for '_Find'), yet it is compared with a
                # time index here, and an empty stack would raise
                # IndexError. Behaviour kept as-is; confirm the intent.
                neat_stack = reuse_stack.copy()
                for prev_time in range(t_id - 1, 0, -1):
                    if neat_stack[-1][-2] == prev_time:
                        neat_stack.pop(-1)
                out = self.assemble_exclude(text_att, round_id, neat_stack)
            elif cur_op_name == '_Transform':
                out = weaver.transform(inputs[0], image, text_att)
            elif cur_op_name == '_Describe':
                out = weaver.describe(inputs[0], image, text_att)
                # TODO: Do this more carefully!
                # last argument popped for this module
                penult_out = inputs[-1]
            elif cur_op_name == '_And':
                out = weaver.and_op(inputs[0], inputs[1])

            # collect outputs from all modules (visualize)
            if self.visualize:
                if _module_output_type[cur_op_name] == 'att':
                    vis_outputs['att'].append((out, r_id))
                if weights is not None:
                    vis_outputs['weights'].extend(weights)

            # also add refer logits to usual outputs (`weights` is only set
            # by assemble_refer, which sets `logits` as well)
            if weights is not None and executor.params['train_mode']:
                outputs.extend(logits)
            decode_stack.append((out, _module_output_type[cur_op_name]))

        # Check if only one element is left
        if len(decode_stack) != 1:
            invalid_prog = True
        # final output is not answer type
        elif decode_stack[0][1] != 'ans':
            invalid_prog = True

        # record program validity (True means the program is INVALID)
        validity.append(invalid_prog)

        # if program is invalid, return zeros
        if invalid_prog:
            outputs.append(weaver.invalid(image))
        else:
            outputs.append(decode_stack[-1][0])
            if fact is not None:
                # record fact embedding against penultimate output
                reuse_stack.append(
                    (penult_out, fact_slice, round_id, r_id, -1))

    return {'comp': outputs, 'vis': vis_outputs}, reuse_stack, validity