def __init__()

in relogic/pretrainkit/datasets/semparse/rat_text2sql.py [0:0]


  def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, local_rank=-1):
    self.examples = []
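    # Output label vocabulary: SQL keywords plus the special bos/eos label tokens used by the decoder.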
    self.keywords = label_mapping["keyword"]
    self.label_eos_id = self.keywords.index(label_mapping["label_eos_token"])
    self.label_bos_id = self.keywords.index(label_mapping["label_bos_token"])
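    # Byte-level BPE tokenizers (BART / RoBERTa) need add_prefix_space to tokenize words that are not sequence-initial.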
    add_prefix_space = isinstance(tokenizer, (BartTokenizer, RobertaTokenizer))
    total, valid = 0, 0
    with open(file_path, encoding="utf-8") as f:
      for line in tqdm(f):
        total += 1
        example = json.loads(line)
        text = example["normalized_question"]
        columns = example["columns"]
        tables = example["tables"]
        columns_text = example["column_text"]
        tables_text = example["table_text"]
        sql = example["sql"]
        # Track where each whitespace token's sub-tokens start, so that sc_link and cv_link
        # keys (which use word indices) can be remapped to sub-token indices below.
        token_idx_to_sub_token_start_idx = {}
        text_tokens = [tokenizer.cls_token]
        start_idx = 0  # running sub-token offset
        for idx, token in enumerate(text.split()):
          sub_tokens = tokenizer.tokenize(token, add_prefix_space=add_prefix_space)
          token_idx_to_sub_token_start_idx[idx] = start_idx
          text_tokens.extend(sub_tokens)
          start_idx += len(sub_tokens)
        text_tokens.append(tokenizer.sep_token)
        question_start, question_end = 1, len(text_tokens) - 1  # end is exclusive; excludes cls and the trailing sep

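        # Append each column's sub-tokens followed by a sep token; record its (start, end) span, end exclusive.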
        column_spans = []
        start_idx = len(text_tokens)
        for column_words in columns_text:
          column_str = " ".join(column_words)
          column_sub_tokens = tokenizer.tokenize(column_str, add_prefix_space=add_prefix_space)
          text_tokens.extend(column_sub_tokens)
          text_tokens.append(tokenizer.sep_token)
          end_idx = start_idx + len(column_sub_tokens)
          column_spans.append((start_idx, end_idx))
          start_idx = end_idx + 1

        column_start = [column_span[0] for column_span in column_spans]
        column_end = [column_span[1] for column_span in column_spans]

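        # Same for tables: append sub-tokens plus a sep token and record their spans.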
        table_spans = []
        start_idx = len(text_tokens)
        for table_words in tables_text:
          table_str = " ".join(table_words)
          table_sub_tokens = tokenizer.tokenize(table_str, add_prefix_space=add_prefix_space)
          text_tokens.extend(table_sub_tokens)
          text_tokens.append(tokenizer.sep_token)
          end_idx = start_idx + len(table_sub_tokens)
          table_spans.append((start_idx, end_idx))
          start_idx = end_idx + 1

        table_start = [table_span[0] for table_span in table_spans]
        table_end = [table_span[1] for table_span in table_spans]

        input_ids = tokenizer.convert_tokens_to_ids(text_tokens)

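        # Skip examples whose full input does not fit into the encoder block.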
        if len(input_ids) > block_size:
          continue

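        # Encode the SQL into decoder labels: keywords map to their index in self.keywords,
        # column references map to len(self.keywords) + column index.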
        label_ids = []
        try:
          for token in sql.split():
            if token in columns:
              label_ids.append(columns.index(token) + len(self.keywords))
            else:
              label_ids.append(self.keywords.index(token))
        except ValueError:
          # The SQL contains a token that is neither a known keyword nor a column; skip this example.
          continue

        label_ids = [self.label_bos_id] + label_ids + [self.label_eos_id]

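        # Schema structure: primary-key column indices and foreign keys given as "col_idx,col_idx" strings.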
        primary_key = [int(x) for x in example["sc_struct"]["primary_key"]]
        foreign_key = {x.split(",")[0]: int(x.split(",")[1]) for x in example["sc_struct"]["foreign_key"]}
        column_to_table = {"0": None}

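        # Remap schema-linking keys ("question_word_idx,column_or_table_idx") to use sub-token start indices.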
        sc_link = {"q_col_match": {}, "q_tab_match": {}}
        for k, v in example["sc_link"]["q_col_match"].items():
          new_k = str(token_idx_to_sub_token_start_idx[int(k.split(",")[0])]) + "," + k.split(",")[1]
          sc_link["q_col_match"][new_k] = v

        for k, v in example["sc_link"]["q_tab_match"].items():
          new_k = str(token_idx_to_sub_token_start_idx[int(k.split(",")[0])]) + "," + k.split(",")[1]
          sc_link["q_tab_match"][new_k] = v

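        # Same remapping for value-linking matches (number/date and cell-value matches).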
        cv_link = {"num_date_match": {}, "cell_match": {}}
        for k, v in example["cv_link"]["num_date_match"].items():
          new_k = str(token_idx_to_sub_token_start_idx[int(k.split(",")[0])]) + "," + k.split(",")[1]
          cv_link["num_date_match"][new_k] = v
        for k, v in example["cv_link"]["cell_match"].items():
          new_k = str(token_idx_to_sub_token_start_idx[int(k.split(",")[0])]) + "," + k.split(",")[1]
          cv_link["cell_match"][new_k] = v


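        # Map every column except "*" (index 0, pre-set to None) to the index of its table.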
        for idx, column in enumerate(columns):
          if column == "*":
            continue
          t = column.split(".")[0]
          column_to_table[str(idx)] = tables.index(t)

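        # Collapse column-level foreign keys into table-level links: source table -> list of referenced tables.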
        foreign_keys_tables = {}
        for k, v in foreign_key.items():
          t_k = str(column_to_table[str(k)])
          t_v = str(column_to_table[str(v)])
          if t_k not in foreign_keys_tables:
            foreign_keys_tables[t_k] = []
          if int(t_v) not in foreign_keys_tables[t_k]:
            foreign_keys_tables[t_k].append(int(t_v))

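        # Store encoder inputs, relation metadata, and decoder labels for this example.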
        self.examples.append({
          "input_ids": input_ids,
          "example_info": {
            "normalized_question": text,
            "columns": columns,
            "tables": tables,
            "tokens": text_tokens,
            "question_start": question_start,
            "question_end": question_end,
            "column_start": torch.LongTensor(column_start),
            "column_end": torch.LongTensor(column_end),
            "table_start": torch.LongTensor(table_start),
            "table_end": torch.LongTensor(table_end),
            "sc_link": sc_link,
            "cv_link": cv_link,
            "primary_keys": primary_key,
            "foreign_keys": foreign_key,
            "column_to_table": column_to_table,
            "foreign_keys_tables": foreign_keys_tables
          },
          "column_spans": column_spans,
          "label_ids": label_ids})
        valid += 1
    print("Valid Example {}; Invalid Example {}".format(valid, total - valid))