def segregrate_outputs()

in models_vd/executor.py [0:0]


  def segregrate_outputs(self, output):
    '''
      Go over the outputs and separate the caption-program and
      question-program predictions into attention maps (sep_att)
      and _Refer weights (sep_wts).
    '''
    # caption program tokens (only present for the caption NMN variant)
    if 'nmn-cap' in self.params['model']:
      cap_tokens = output['pred_tokens_cap'][:, 0]
    # question program tokens, one program per dialog round
    ques_tokens = output['pred_tokens']
    # mod_out_type: module name -> output type; mod_dict: token id -> module name
    mod_out_type = _module_output_type
    mod_dict = self._assembler.module_names

    att = output['att']
    # _Refer logits are interpreted as weights when visualizing
    weights = output['logits']

    # segregated outputs
    sep_att = []    # (source, time_step, round_id, attention_map)
    sep_wts = []    # (round_id, weight_slice, weight_labels)
    wt_labels = []  # labels of reusable outputs a _Refer can attend over
    num_reuse = 0   # number of reusable outputs collected so far
    att_ind = 0     # index into the stacked attention outputs
    weight_ind = 0  # index into the stacked _Refer weights
    # go over the caption program (if present)
    if 'nmn-cap' in self.params['model']:
      for t_id in range(self.params['max_dec_len']):
        cur_module = mod_dict[cap_tokens[t_id]]
        if cur_module == '<eos>': break
        if mod_out_type[cur_module] == 'att':
          sep_att.append(('cap', t_id, 0, att[att_ind]))
          att_ind += 1

          if cur_module == '_Find':
            # caption _Find outputs are reusable by later _Refer modules
            wt_labels.append('C_%d' % t_id)
            num_reuse += 1

    # assume a batch size of 1
    for r_id in range(self.params['num_rounds']):
      for t_id in range(self.params['max_dec_len']):
        cur_module = mod_dict[ques_tokens[t_id, r_id]]
        if cur_module == '<eos>':
          # with use_fact, the answer of each round is also a reusable output
          if self.params['use_fact']:
            wt_labels.append('A%d' % r_id)
            num_reuse += 1
          break

        if mod_out_type[cur_module] == 'att':
          sep_att.append(('ques', t_id, r_id, att[att_ind]))
          att_ind += 1

        if cur_module == '_Refer':
          # slice out this _Refer's weights over all reusable outputs so far
          st = weight_ind
          end = weight_ind + num_reuse
          sep_wts.append((r_id, weights[st:end], wt_labels))
          weight_ind += num_reuse

        if cur_module == '_Find':
          # question _Find outputs are reusable in later rounds as well
          wt_labels.append('Q%d_%d' % (r_id, t_id))
          num_reuse += 1

    # do not assert if baseline
    if 'baseline' in self.params['model']:
      return sep_att, sep_wts

    # each _Refer weight slice must be a normalized distribution
    for arg in sep_wts:
      assert abs(np.sum(arg[1]) - 1.0) < 1e-5

    # Sanity checks to ensure Refer is not doing anything weird.
    assert weight_ind == weights.shape[0]
    assert att_ind == att.shape[0]

    return sep_att, sep_wts
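

A minimal consumption sketch (not part of models_vd/executor.py). It relies only
on the tuple layouts built above, where sep_att entries are
(source, t_id, r_id, attention_map) and sep_wts entries are
(r_id, weight_slice, weight_labels), and it assumes numpy and flat weight
slices. The helper name summarize_segregated_outputs and the executor/output
objects in the usage comment are illustrative placeholders.

  import numpy as np

  def summarize_segregated_outputs(sep_att, sep_wts):
    '''Print a short summary of the segregated attentions and _Refer weights.'''
    for source, t_id, r_id, att_map in sep_att:
      # one attention map produced by an 'att'-type module
      print('%s program, round %d, step %d: attention shape %s'
            % (source, r_id, t_id, np.asarray(att_map).shape))

    for r_id, wts, labels in sep_wts:
      # the first len(wts) labels name the reusable outputs (C_*, Q*_*, A*)
      # that this round's _Refer distributed its weight over
      ranked = sorted(zip(labels, np.ravel(wts)), key=lambda x: -x[1])
      print('round %d _Refer weights: %s'
            % (r_id, ', '.join('%s=%.3f' % (l, w) for l, w in ranked)))

  # hypothetical usage:
  # sep_att, sep_wts = executor.segregrate_outputs(output)
  # summarize_segregated_outputs(sep_att, sep_wts)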