models_mnist/assembler.py [51:264]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                       '_Describe': 'ans'
                      }

# Named string constant; its value is identical to its name.
INVALID_EXPR = "INVALID_EXPR"

# decoding validity: maintain a state x = [#att, #ans, T_remain],
# where T_remain equals T_decoder when decoding the first module token.
# A token s can be predicted iff all(<x, w_s> - b_s >= 0).
# For all tokens at once, the validity constraints are
#       XW - b >= 0
# The state transition matrix is P, so the state update is X += S P,
# where S holds the predicted tokens (one-hot vectors).

def _build_validity_mats(module_names):
  """Build the matrices used to mask invalid tokens while decoding a layout.

  The decoder keeps a state vector x = [#att, #ans, T_remain]: the number of
  attention maps and answers on the RPN stack and the number of decoding
  steps left. Token s may be emitted iff every entry of
  x @ W[:, s, :] - b[s] is >= 0, and emitting it updates the state by
  x += P[s].

  Args:
    module_names: vocabulary list; entry n is the name of token n. Every
      name other than '<eos>' must be a key of _module_input_num and
      _module_output_type.

  Returns:
    P: int32 array of shape [num_vocab, 3], per-token state change.
    W: int32 array of shape [3, num_vocab, 4], constraint weights.
    b: int32 array of shape [num_vocab, 4], constraint offsets.

  Raises:
    ValueError: if module_names is empty.
    KeyError: if a non-'<eos>' name is missing from the module tables.
  """
  state_size = 3
  num_vocab_nmn = len(module_names)
  num_constraints = 4
  if num_vocab_nmn == 0:
    raise ValueError('module_names must contain at least one token')
  P = np.zeros((num_vocab_nmn, state_size), np.int32)
  W = np.zeros((state_size, num_vocab_nmn, num_constraints), np.int32)
  b = np.zeros((num_vocab_nmn, num_constraints), np.int32)

  # collect the number of att inputs and att/ans outputs of each module;
  # <eos> consumes and produces nothing
  att_in_nums = np.zeros(num_vocab_nmn)
  att_out_nums = np.zeros(num_vocab_nmn)
  ans_out_nums = np.zeros(num_vocab_nmn)
  for n_s, s in enumerate(module_names):
    if s != '<eos>':
      att_in_nums[n_s] = _module_input_num[s]
      att_out_nums[n_s] = _module_output_type[s] == 'att'
      ans_out_nums[n_s] = _module_output_type[s] == 'ans'

  # transition matrix P: emitting a token changes #att by
  # (att outputs - att inputs), #ans by its ans outputs, T_remain by -1
  P[:, 0] = att_out_nums - att_in_nums
  P[:, 1] = ans_out_nums
  P[:, 2] = -1

  # net number of atts each token removes from the stack, and the largest
  # such value among non-answer (MANA) and answer (MAA) tokens
  att_absorb_nums = (att_in_nums - att_out_nums)
  max_att_absorb_nonans = np.max(att_absorb_nums * (ans_out_nums == 0))
  max_att_absorb_ans = np.max(att_absorb_nums * (ans_out_nums != 0))

  for n_s, s in enumerate(module_names):
    if s != '<eos>':
      # a non-<eos> module can be emitted iff all of the following hold:
      # * 0) there are enough atts on the stack
      #      #att >= att_in_nums[n_s]
      W[0, n_s, 0] = 1
      b[n_s, 0] = att_in_nums[n_s]

      # * 1) for answer modules, no extra att may remain on the stack
      #      #att <= att_in_nums[n_s]  <=>  -#att >= -att_in_nums[n_s]
      #      for non-answer modules, T_remain >= 3
      #      (this module, then at least an answer module and <eos>)
      if ans_out_nums[n_s] != 0:
        W[0, n_s, 1] = -1
        b[n_s, 1] = -att_in_nums[n_s]
      else:
        W[2, n_s, 1] = 1
        b[n_s, 1] = 3

      # * 2) there's no answer on the stack yet (otherwise <eos> only)
      #      #ans <= 0  <=>  -#ans >= 0
      W[1, n_s, 2] = -1

      # * 3) there's enough time left to consume all the atts, output an
      #      answer, and then <eos>
      #      3.1) for non-answer modules, constraint 1 already gives
      #           T_remain >= 3. At most (T_remain-3) further steps
      #           (plus 3 steps for this module, the answer, and <eos>)
      #           are available to consume atts:
      #           (T_remain-3) * MANA + MAA + A[s] >= #att
      #           <=>  -#att + MANA * T_remain >= 3*MANA - MAA - A[s]
      #           with MANA = max_att_absorb_nonans,
      #                MAA = max_att_absorb_ans, A[s] = att_absorb_nums[n_s]
      #      3.2) for answer modules, if they can be decoded then
      #           constraints 0 & 1 ensure no att is left on the stack after
      #           this answer, hence no further constraint here
      if ans_out_nums[n_s] == 0:
        W[0, n_s, 3] = -1
        W[2, n_s, 3] = max_att_absorb_nonans
        b[n_s, 3] = (3 * max_att_absorb_nonans - max_att_absorb_ans -
                     att_absorb_nums[n_s])

    else:  # <eos>-case
      # an <eos> token can be emitted iff there's an answer on the stack
      #      #ans >= 1
      W[1, n_s, 0] = 1
      b[n_s, 0] = 1

  return P, W, b
#------------------------------------------------------------------------------

class Assembler:
  def __init__(self, module_vocab_file):
    # read the module list, and record the index of each module and <eos>
    with open(module_vocab_file) as f:
      self.module_names = [s.strip() for s in f.readlines()]

    # find the index of <eos>
    for n_s in range(len(self.module_names)):
      if self.module_names[n_s] == '<eos>':
        self.EOS_idx = n_s
        break
    # build a dictionary from module name to token index
    self.name2idx_dict = {name: n_s for n_s, name in enumerate(self.module_names)}
    self.num_vocab_nmn = len(self.module_names)

    self.P, self.W, self.b = _build_validity_mats(self.module_names)

  def module_list2tokens(self, module_list, T=None):
    layout_tokens = [self.name2idx_dict[name] for name in module_list]
    if T is not None:
      if len(module_list) >= T:
        raise ValueError('Not enough time steps to add <eos>')
      layout_tokens += [self.EOS_idx]*(T-len(module_list))
    return layout_tokens

  def _layout_tokens2str(self, layout_tokens):
    return ' '.join([self.module_names[idx] for idx in layout_tokens])

  def assemble_refer(self, text_att, round_id, reuse_stack):
    # aliases
    weaver = self.weaver
    executor = self.executor

    # compute the scores
    logits = []
    for find_arg in reuse_stack:
      # compute the weights for each of the attention map
      inputs = (text_att, find_arg[1], round_id, find_arg[2])
      logits.append(weaver.align_text(*inputs))

    # exponential each logit
    weights = []
    for ii in logits: weights.append(weaver.exp(ii))

    # normalize the weights
    if len(weights) < 2:
      norm = weights[0]
    else:
      norm = weaver.add(weights[0], weights[1])
      for ii in weights[2:]: norm = weaver.add(norm, ii)
    for index, ii in enumerate(weights):
      weights[index] = weaver.divide(ii, norm)

    # multiply the attention with softmax weight
    prev_att = []
    for (att, _, _, _, _), weight in zip(reuse_stack, weights):
      prev_att.append(weaver.weight_attention(att, weight))

    # add all attentions to get the result
    if len(prev_att) < 2: out = prev_att[0]
    else:
      out = weaver.add_attention(prev_att[0], prev_att[1])
      for ii in prev_att[2:]:
        out = weaver.add_attention(out, ii)

    return out, weights, logits

  def assemble_exclude(self, text_att, round_id, reuse_stack):
    # aliases
    weaver = self.weaver
    executor = self.executor

    # compute the scores
    weights = []
    exclude_att = reuse_stack[0][0]
    if len(reuse_stack) > 1:
      for find_arg in reuse_stack:
        exclude_att = weaver.max_attention(exclude_att, find_arg[0])

    return weaver.normalize_exclude(exclude_att)

  # code to check if the program makes sense
  # typically contains all the checks from the _assemble_program method
  def sanity_check_program(self, layout):
    decode_stack = []
    for t_id, cur_op_id in enumerate(layout):
      cur_op_name = self.module_names[cur_op_id]
      # <eos> would mean stop
      if cur_op_id == self.EOS_idx: break

      # insufficient number of inputs
      num_inputs = _module_input_num[cur_op_name]
      if len(decode_stack) < num_inputs:
        return False, 'Insufficient inputs'

      # read the inputs
      inputs = []
      for ii in range(num_inputs):
        arg_type = decode_stack.pop()
        # cannot consume anything but attention
        if arg_type != 'att':
          return False, 'Intermediate not attention'

      decode_stack.append(_module_output_type[cur_op_name])

    # Check if only one element is left
    if len(decode_stack) != 1:
      return False, 'Left with more than one outputs'
    # final output is not answer type
    elif decode_stack[0] != 'ans':
      return False, 'Final output not an answer'

    return True, 'Valid program'

  def assemble(self, layout_tokens, executor, visualize=False):
    # layout_tokens_batch is a numpy array with shape [T, N],
    # containing module tokens and <eos>, in Reverse Polish Notation.

    # internalize executor and weaver
    self.executor = executor
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



models_vd/assembler.py [43:255]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                       '_Describe': 'ans'
                      }

# Named string constant; its value is identical to its name.
INVALID_EXPR = "INVALID_EXPR"

# decoding validity: maintain a state x = [#att, #ans, T_remain],
# where T_remain equals T_decoder when decoding the first module token.
# A token s can be predicted iff all(<x, w_s> - b_s >= 0).
# For all tokens at once, the validity constraints are
#       XW - b >= 0
# The state transition matrix is P, so the state update is X += S P,
# where S holds the predicted tokens (one-hot vectors).

def _build_validity_mats(module_names):
  """Build the matrices used to mask invalid tokens while decoding a layout.

  The decoder keeps a state vector x = [#att, #ans, T_remain]: the number of
  attention maps and answers on the RPN stack and the number of decoding
  steps left. Token s may be emitted iff every entry of
  x @ W[:, s, :] - b[s] is >= 0, and emitting it updates the state by
  x += P[s].

  Args:
    module_names: vocabulary list; entry n is the name of token n. Every
      name other than '<eos>' must be a key of _module_input_num and
      _module_output_type.

  Returns:
    P: int32 array of shape [num_vocab, 3], per-token state change.
    W: int32 array of shape [3, num_vocab, 4], constraint weights.
    b: int32 array of shape [num_vocab, 4], constraint offsets.

  Raises:
    ValueError: if module_names is empty.
    KeyError: if a non-'<eos>' name is missing from the module tables.
  """
  state_size = 3
  num_vocab_nmn = len(module_names)
  num_constraints = 4
  if num_vocab_nmn == 0:
    raise ValueError('module_names must contain at least one token')
  P = np.zeros((num_vocab_nmn, state_size), np.int32)
  W = np.zeros((state_size, num_vocab_nmn, num_constraints), np.int32)
  b = np.zeros((num_vocab_nmn, num_constraints), np.int32)

  # collect the number of att inputs and att/ans outputs of each module;
  # <eos> consumes and produces nothing
  att_in_nums = np.zeros(num_vocab_nmn)
  att_out_nums = np.zeros(num_vocab_nmn)
  ans_out_nums = np.zeros(num_vocab_nmn)
  for n_s, s in enumerate(module_names):
    if s != '<eos>':
      att_in_nums[n_s] = _module_input_num[s]
      att_out_nums[n_s] = _module_output_type[s] == 'att'
      ans_out_nums[n_s] = _module_output_type[s] == 'ans'

  # transition matrix P: emitting a token changes #att by
  # (att outputs - att inputs), #ans by its ans outputs, T_remain by -1
  P[:, 0] = att_out_nums - att_in_nums
  P[:, 1] = ans_out_nums
  P[:, 2] = -1

  # net number of atts each token removes from the stack, and the largest
  # such value among non-answer (MANA) and answer (MAA) tokens
  att_absorb_nums = (att_in_nums - att_out_nums)
  max_att_absorb_nonans = np.max(att_absorb_nums * (ans_out_nums == 0))
  max_att_absorb_ans = np.max(att_absorb_nums * (ans_out_nums != 0))

  for n_s, s in enumerate(module_names):
    if s != '<eos>':
      # a non-<eos> module can be emitted iff all of the following hold:
      # * 0) there are enough atts on the stack
      #      #att >= att_in_nums[n_s]
      W[0, n_s, 0] = 1
      b[n_s, 0] = att_in_nums[n_s]

      # * 1) for answer modules, no extra att may remain on the stack
      #      #att <= att_in_nums[n_s]  <=>  -#att >= -att_in_nums[n_s]
      #      for non-answer modules, T_remain >= 3
      #      (this module, then at least an answer module and <eos>)
      if ans_out_nums[n_s] != 0:
        W[0, n_s, 1] = -1
        b[n_s, 1] = -att_in_nums[n_s]
      else:
        W[2, n_s, 1] = 1
        b[n_s, 1] = 3

      # * 2) there's no answer on the stack yet (otherwise <eos> only)
      #      #ans <= 0  <=>  -#ans >= 0
      W[1, n_s, 2] = -1

      # * 3) there's enough time left to consume all the atts, output an
      #      answer, and then <eos>
      #      3.1) for non-answer modules, constraint 1 already gives
      #           T_remain >= 3. At most (T_remain-3) further steps
      #           (plus 3 steps for this module, the answer, and <eos>)
      #           are available to consume atts:
      #           (T_remain-3) * MANA + MAA + A[s] >= #att
      #           <=>  -#att + MANA * T_remain >= 3*MANA - MAA - A[s]
      #           with MANA = max_att_absorb_nonans,
      #                MAA = max_att_absorb_ans, A[s] = att_absorb_nums[n_s]
      #      3.2) for answer modules, if they can be decoded then
      #           constraints 0 & 1 ensure no att is left on the stack after
      #           this answer, hence no further constraint here
      if ans_out_nums[n_s] == 0:
        W[0, n_s, 3] = -1
        W[2, n_s, 3] = max_att_absorb_nonans
        b[n_s, 3] = (3 * max_att_absorb_nonans - max_att_absorb_ans -
                     att_absorb_nums[n_s])

    else:  # <eos>-case
      # an <eos> token can be emitted iff there's an answer on the stack
      #      #ans >= 1
      W[1, n_s, 0] = 1
      b[n_s, 0] = 1

  return P, W, b
#------------------------------------------------------------------------------

class Assembler:
  def __init__(self, module_vocab_file):
    # read the module list, and record the index of each module and <eos>
    with open(module_vocab_file) as f:
      self.module_names = [s.strip() for s in f.readlines()]
    # find the index of <eos>
    for n_s in range(len(self.module_names)):
      if self.module_names[n_s] == '<eos>':
        self.EOS_idx = n_s
        break
    # build a dictionary from module name to token index
    self.name2idx_dict = {name: n_s for n_s, name in enumerate(self.module_names)}
    self.num_vocab_nmn = len(self.module_names)

    self.P, self.W, self.b = _build_validity_mats(self.module_names)

  def module_list2tokens(self, module_list, T=None):
    layout_tokens = [self.name2idx_dict[name] for name in module_list]
    if T is not None:
      if len(module_list) >= T:
        raise ValueError('Not enough time steps to add <eos>')
      layout_tokens += [self.EOS_idx]*(T-len(module_list))
    return layout_tokens

  def _layout_tokens2str(self, layout_tokens):
    return ' '.join([self.module_names[idx] for idx in layout_tokens])

  def assemble_refer(self, text_att, round_id, reuse_stack):
    # aliases
    weaver = self.weaver
    executor = self.executor

    # compute the scores
    logits = []
    for find_arg in reuse_stack:
      # compute the weights for each of the attention map
      inputs = (text_att, find_arg[1], round_id, find_arg[2])
      logits.append(weaver.align_text(*inputs))

    # exponential each logit
    weights = []
    for ii in logits: weights.append(weaver.exp(ii))

    # normalize the weights
    if len(weights) < 2:
      norm = weights[0]
    else:
      norm = weaver.add(weights[0], weights[1])
      for ii in weights[2:]: norm = weaver.add(norm, ii)
    for index, ii in enumerate(weights):
      weights[index] = weaver.divide(ii, norm)

    # multiply the attention with softmax weight
    prev_att = []
    for (att, _, _, _, _), weight in zip(reuse_stack, weights):
      prev_att.append(weaver.weight_attention(att, weight))

    # add all attentions to get the result
    if len(prev_att) < 2: out = prev_att[0]
    else:
      out = weaver.add_attention(prev_att[0], prev_att[1])
      for ii in prev_att[2:]:
        out = weaver.add_attention(out, ii)

    return out, weights, logits

  def assemble_exclude(self, text_att, round_id, reuse_stack):
    # aliases
    weaver = self.weaver
    executor = self.executor

    # compute the scores
    weights = []
    exclude_att = reuse_stack[0][0]
    if len(reuse_stack) > 1:
      for find_arg in reuse_stack:
        exclude_att = weaver.max_attention(exclude_att, find_arg[0])

    return weaver.normalize_exclude(exclude_att)

  # code to check if the program makes sense
  # typically contains all the checks from the _assemble_program method
  def sanity_check_program(self, layout):
    decode_stack = []
    for t_id, cur_op_id in enumerate(layout):
      cur_op_name = self.module_names[cur_op_id]
      # <eos> would mean stop
      if cur_op_id == self.EOS_idx: break

      # insufficient number of inputs
      num_inputs = _module_input_num[cur_op_name]
      if len(decode_stack) < num_inputs:
        return False, 'Insufficient inputs'

      # read the inputs
      inputs = []
      for ii in range(num_inputs):
        arg_type = decode_stack.pop()
        # cannot consume anything but attention
        if arg_type != 'att':
          return False, 'Intermediate not attention'

      decode_stack.append(_module_output_type[cur_op_name])

    # Check if only one element is left
    if len(decode_stack) != 1:
      return False, 'Left with more than one outputs'
    # final output is not answer type
    elif decode_stack[0] != 'ans':
      return False, 'Final output not an answer'

    return True, 'Valid program'

  def assemble(self, layout_tokens, executor, visualize=False):
    # layout_tokens_batch is a numpy array with shape [T, N],
    # containing module tokens and <eos>, in Reverse Polish Notation.

    # internalize executor and weaver
    self.executor = executor
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



