def featurize_each_example()

in research/a2n/dataset.py [0:0]


  def featurize_each_example(self, example_tuple):
    """Convert one example into padded arrays for input to the model.

    Args:
      example_tuple: a `(s, r, t, reverse)` tuple. When `reverse` is truthy the
        known answers are read from the `all_reachable_e2_reverse` index keyed
        by `(t, r)` instead of `all_reachable_e2` keyed by `(s, r)`, and
        (outside train mode) `s` and `t` are swapped before featurization.

    Returns:
      `(s, nbrs_s, r, candidates, nbrs_candidates, labels)`. `text_nbrs_s` is
      inserted after `nbrs_s` when the train graph is a `CWTextGraph` or
      `self.max_text_len` is set, and `text_nbrs_s_emb` is appended at the end
      in the `CWTextGraph` case. `candidates` is the target followed by the
      sampled negatives (jointly shuffled with `labels` in train mode).
      `labels` is `[t]` outside train mode and a 0/1 array marking the target
      position in train mode.
    """
    s, r, t, reverse = example_tuple
    # Pick the index and key once, using the un-swapped s/t of the example.
    if reverse:
      index_name, key = "all_reachable_e2_reverse", (t, r)
    else:
      index_name, key = "all_reachable_e2", (s, r)

    all_targets = getattr(self.train_graph, index_name)[key]
    if self.mode != "train":
      # Use `a = a | b` rather than `a |= b`: all_targets is the very set
      # object stored inside train_graph, and an in-place union would leak
      # val/test answers into the training graph's index.
      # add all correct candidates from val/test set
      all_targets = all_targets | getattr(self.data_graph, index_name)[key]
      if self.val_graph:
        # if provided also remove val tuples for testing
        all_targets = all_targets | getattr(self.val_graph, index_name)[key]
      if reverse:
        # NOTE(review): the swap only happens outside train mode, as in the
        # original code -- confirm reverse examples never occur in training.
        # switch s and t
        s, t = t, s

    # Negatives are every entity that is neither a known answer, the target
    # itself, nor the padding entity.
    excluded = all_targets | {t, self.train_graph.ent_pad}
    candidate_negatives = list(self.train_graph.all_entities - excluded)
    negatives = sample_or_pad(
        np.array(candidate_negatives, dtype=int), self.max_negatives,
        pad_value=self.train_graph.ent_pad
    )
    # negatives is an array of shape (max_negatives)
    # candidates will have shape (max_negatives + 1), i.e including the target
    candidates = np.insert(negatives, 0, t, axis=0)

    # Choose the neighborhood extractor and the matching padding element.
    if self.model_type == "source_rel_attention":
      nbrhd_fn = get_graph_nbrhd_with_rels
      pad_value = [self.train_graph.rel_pad, self.train_graph.ent_pad]
    elif self.model_type == "source_path_attention":
      def nbrhd_fn(graph, entity, triple):
        return get_graph_nbrhd_paths_randwalk(
            graph, entity, triple,
            max_length=self.train_graph.max_path_length,
            max_paths=self.max_neighbors, terminate_prob=0.1,
            pad=(self.train_graph.rel_pad, self.train_graph.ent_pad)
        )
      pad_value = [self.train_graph.rel_pad, self.train_graph.ent_pad] * \
        self.train_graph.max_path_length
    else:
      nbrhd_fn = get_graph_nbrhd
      pad_value = self.train_graph.ent_pad

    if self.model_type == "distmult":
      # No neighborhood features at all.
      nbrs_s = np.array([], dtype=int)
      nbrs_candidates = np.array([], dtype=int)
    elif self.model_type in ["source_attention", "source_rel_attention",
                             "source_path_attention"]:
      # Only the source entity gets a neighborhood.
      nbrs_s = sample_or_pad(nbrhd_fn(self.train_graph, s, (s, r, t)),
                             self.max_neighbors,
                             pad_value=pad_value)
      if isinstance(self.train_graph, clueweb_text_graph.CWTextGraph):
        # this func does padding in there
        text_nbrs_s, text_nbrs_s_emb = get_graph_nbrhd_embd_text(
            self.train_graph, s, self.max_text_neighbors)
      elif self.max_text_len:
        text_pad_value = [self.train_graph.ent_pad] + \
              [self.train_graph.vocab[self.train_graph.mask_token]] * \
              self.max_text_len
        text_nbrs_s = sample_or_pad(
            get_graph_nbrhd_text(self.train_graph, s, self.max_text_len),
            self.max_text_neighbors, pad_value=text_pad_value
        )
      nbrs_candidates = np.array([], dtype=int)
    else:
      # Source, target and every negative get their own neighborhood.
      nbrs_s = sample_or_pad(nbrhd_fn(self.train_graph, s, (s, r, t)),
                             self.max_neighbors,
                             pad_value=pad_value)
      nbrs_t = sample_or_pad(nbrhd_fn(self.train_graph, t, (s, r, t)),
                             self.max_neighbors,
                             pad_value=pad_value)
      nbrs_negatives = np.array(
          [sample_or_pad(nbrhd_fn(self.train_graph, cand, (s, r, t)),
                         self.max_neighbors,
                         pad_value=pad_value)
           for cand in negatives]
      )
      # Row 0 is the target, matching the layout of `candidates`.
      nbrs_candidates = np.concatenate(
          (np.expand_dims(nbrs_t, 0), nbrs_negatives), axis=0
      )

    if self.mode != "train":
      labels = [t]
    else:
      labels = np.zeros(candidates.shape[0], dtype=int)
      labels[0] = 1
      # Shuffle so the target is not always at position 0.
      idx = np.arange(candidates.shape[0])
      np.random.shuffle(idx)
      candidates = candidates[idx]
      # NOTE(review): per-candidate neighborhoods are only re-ordered for
      # model_type "attention"; any other model type reaching the final
      # branch above would keep them in the un-shuffled order -- confirm.
      if self.model_type == "attention":
        nbrs_candidates = nbrs_candidates[idx]
      labels = labels[idx]

    # NOTE(review): text_nbrs_s / text_nbrs_s_emb are only assigned in the
    # "source_*" branch above; the two returns below raise UnboundLocalError
    # for other model types when text features are enabled.
    if isinstance(self.train_graph, clueweb_text_graph.CWTextGraph):
      return s, nbrs_s, text_nbrs_s, r, candidates, nbrs_candidates, labels, \
             text_nbrs_s_emb
    elif self.max_text_len:
      return s, nbrs_s, text_nbrs_s, r, candidates, nbrs_candidates, labels
    return s, nbrs_s, r, candidates, nbrs_candidates, labels