def _assemble_program()

in models_mnist/assembler.py [0:0]


  def _assemble_program(self, image, text, fact, text_feat, tokens, reuse_stack):
    """Assemble one module program per dialog round from layout tokens.

    For every round ``r_id`` the column ``tokens[:, r_id]`` is read as a
    postfix sequence of module ids: each module pops its inputs from a decode
    stack, is instantiated through ``self.weaver`` and pushes its output back.
    Decoding stops at ``self.EOS_idx``. A round's program is invalid when the
    column contains no <eos>, a module lacks inputs or receives a
    non-attention input, ``_Refer``/``_Exclude`` have nothing to look at, the
    module name is not handled, or decoding does not end with exactly one
    'ans'-typed value on the stack.

    Args:
      image: image input forwarded to the weaver modules.
      text: text source sliced per (round, time) via ``weaver.slice_text``.
      fact: fact source sliced per round via ``weaver.slice_fact``; may be
        None. NOTE(review): if None while ``params['use_fact']`` is set,
        ``fact_slice`` is unbound when a valid round is recorded -- confirm
        callers never combine the two.
      text_feat: text features sliced per (round, time); used by ``_Refer``.
      tokens: layout token ids indexed as ``tokens[t_id, r_id]``; presumably
        shaped (max_time, num_rounds) -- confirm against the caller.
      reuse_stack: list of entries from earlier rounds consumed by
        ``_Refer``/``_Exclude``. Mutated in place: when
        ``executor.params['use_fact']`` is set, every valid round appends
        ``(penult_out, fact_slice, round_id, r_id, -1)``.

    Returns:
      ``(programs, reuse_stack, validity)`` where ``programs`` is
      ``{'comp': outputs, 'vis': vis_outputs}`` with one output per round
      (``weaver.invalid(image)`` for invalid rounds), and ``validity[r]`` is
      True when round ``r``'s program is INVALID (the name is historical).
    """
    weaver = self.weaver
    executor = self.executor
    num_rounds = executor.params['num_rounds']

    outputs = []
    validity = []
    # internal nodes collected for visualization
    vis_outputs = {'att': [], 'weights': []}

    for r_id in range(num_rounds):
      layout = tokens[:, r_id]
      round_id = weaver.batch_input(executor._loom_types['round'], r_id)
      if fact is not None: fact_slice = weaver.slice_fact(fact, round_id)

      # a valid layout must contain <eos>; without it the round is invalid,
      # although the remaining layout is still walked below
      invalid_prog = not np.any(layout == self.EOS_idx)

      decode_stack = []
      penult_out = None  # input fed to the final answer module
      for t_id, cur_op_id in enumerate(layout):
        weights = None
        time_id = weaver.batch_input(executor._loom_types['time'], t_id)
        text_att = weaver.slice_text(text, round_id, time_id)
        text_feat_slice = weaver.slice_text_feat(text_feat, round_id, time_id)
        cur_op_name = self.module_names[cur_op_id]

        # <eos> means stop
        if cur_op_id == self.EOS_idx: break

        # insufficient number of inputs
        num_inputs = _module_input_num[cur_op_name]
        if len(decode_stack) < num_inputs:
          invalid_prog = True
          break

        # pop the inputs (top of the stack first); only attention may be used
        inputs = []
        for _ in range(num_inputs):
          arg, arg_type = decode_stack.pop()
          if arg_type != 'att': break
          inputs.append(arg)
        # the `break` above only leaves the input loop; an incomplete input
        # list must also stop this round instead of running the module
        if len(inputs) != num_inputs:
          invalid_prog = True
          break

        if cur_op_name == '_Find':
          out = weaver.find(image, text_att)

        elif cur_op_name == '_Refer':
          # nothing to refer to, wrong program
          if not reuse_stack:
            invalid_prog = True
            break
          # the baseline model takes the first field of the latest entry
          if 'baseline' in executor.params['model']:
            out = reuse_stack[-1][0]
          else:
            out, weights, _ = self.assemble_refer(
                text_feat_slice, round_id, reuse_stack)

        elif cur_op_name == '_Exclude':
          # copy the reuse stack, dropping trailing entries whose [-2] field
          # matches an earlier time step, to avoid excluding current finds.
          # NOTE(review): entries appended by this method store r_id at [-2];
          # comparing it to a time step looks suspicious -- confirm intent.
          neat_stack = reuse_stack.copy()
          for prev_time in range(t_id - 1, 0, -1):
            # the stack may start or become empty; that is reported as an
            # invalid program below rather than raising IndexError here
            if neat_stack and neat_stack[-1][-2] == prev_time:
              neat_stack.pop()

          # nothing to exclude, wrong program
          if not neat_stack:
            invalid_prog = True
            break

          out = self.assemble_exclude(text_att, round_id, neat_stack)

        elif cur_op_name == '_Transform':
          out = weaver.transform(inputs[0], image, text_att)

        # answer modules: remember the attention they consume as penult_out
        elif cur_op_name == '_Describe':
          out = weaver.describe(inputs[0], image, text_att)
          penult_out = inputs[0]

        elif cur_op_name == '_Exist':
          out = weaver.exist(inputs[0], image, text_att)
          penult_out = inputs[0]

        elif cur_op_name == '_Count':
          out = weaver.count(inputs[0], image, text_att)
          penult_out = inputs[0]

        elif cur_op_name == '_And':
          out = weaver.and_op(inputs[0], inputs[1])

        elif cur_op_name == '_Diff':
          out = weaver.diff_op(inputs[0], inputs[1])

        # just invert the attention
        elif cur_op_name == '_Not':
          out = weaver.normalize_exclude(inputs[0])

        else:
          print('Current operand not defined: ' + cur_op_name)
          invalid_prog = True
          # no `out` was built for this step; do not push a stale one
          break

        out_type = _module_output_type[cur_op_name]
        # collect outputs from all modules (visualize)
        if self.visualize:
          if out_type == 'att':
            vis_outputs['att'].append((out, r_id))
          if weights is not None:
            vis_outputs['weights'].extend(weights)

        decode_stack.append((out, out_type))

      # the program must reduce to exactly one 'ans'-typed value
      if len(decode_stack) != 1 or decode_stack[0][1] != 'ans':
        invalid_prog = True

      # True marks an INVALID round despite the list's name
      validity.append(invalid_prog)

      if invalid_prog:
        # invalid rounds get the weaver's placeholder output
        outputs.append(weaver.invalid(image))
      else:
        outputs.append(decode_stack[-1][0])

        # if fact is to be used, expose the answer module's input to later rounds
        if executor.params['use_fact']:
          reuse_stack.append((penult_out, fact_slice, round_id, r_id, -1))

    return {'comp': outputs, 'vis': vis_outputs}, reuse_stack, validity