def generate_for_epoch()

in utils/generate_vertical_tabert_training_data.py [0:0]


def generate_for_epoch(table_db: TableDatabase,
                       indices: List[int],
                       epoch_file: Path,
                       input_formatter: TableBertBertInputFormatter,
                       args: Namespace):
    """Serialize the pretraining instances for one epoch into an HDF5 shard.

    Examples selected by `indices` are converted into pretraining instances by
    `input_formatter`; the flattened token sequences and their offset tables
    are written to `epoch_file`.
    """
    # On the master process, also dump a small random sample of instances as JSON for inspection.
    debug_file = epoch_file.with_suffix('.sample.json') if args.is_master else None
    f_dbg = None
    if debug_file:
        f_dbg = open(debug_file, 'w')

    row_data_sequences = []
    row_data_offsets = []
    mlm_data_sequences = []
    mlm_data_offsets = []

    def _save_shard():
        # Flatten the accumulated buffers into typed arrays, write them as one
        # HDF5 shard, and clear the in-memory buffers.
        data = {
            'row_data_sequences': np.array(row_data_sequences, dtype=np.uint16),
            'row_data_offsets': np.array(row_data_offsets, dtype=np.uint64),
            'mlm_data_sequences': np.array(mlm_data_sequences, dtype=np.uint16),
            'mlm_data_offsets': np.array(mlm_data_offsets, dtype=np.uint64),
        }

        with h5py.File(str(epoch_file), 'w') as f:
            for key, val in data.items():
                f.create_dataset(key, data=val)

        del row_data_sequences[:]
        del row_data_offsets[:]
        del mlm_data_sequences[:]
        del mlm_data_offsets[:]

    for example_idx in tqdm(indices, desc=f"Generating dataset {epoch_file}", file=sys.stdout):
        example = table_db[example_idx]
        try:
            instances = input_formatter.get_pretraining_instances_from_example(example, sample_context)

            for instance in instances:
                # Write a ~5% random sample of instances to the debug JSON file.
                if debug_file and random() <= 0.05:
                    f_dbg.write(json.dumps(instance) + os.linesep)

                input_formatter.remove_unecessary_instance_entries(instance)

                table_data = []
                for row_inst in instance['rows']:
                    row_data = serialize_row_data(row_inst, config=input_formatter.config)
                    table_data.extend(row_data)

                row_data_offsets.append([
                    instance['table_size'][0],  # row_num
                    instance['table_size'][1],  # column_num
                    len(row_data_sequences),  # start index
                    len(row_data_sequences) + len(table_data)  # end index
                ])
                row_data_sequences.extend(table_data)

                # MLM supervision: concatenate the four label arrays into one flat
                # buffer and record the segment boundaries [s1, s2, s3, s4, s5]
                # so each segment can be recovered later.
                s1 = len(mlm_data_sequences)
                mlm_data = []

                mlm_data.extend(instance['masked_context_token_positions'])
                s2 = s1 + len(mlm_data)

                mlm_data.extend(instance['masked_context_token_label_ids'])
                s3 = s1 + len(mlm_data)

                mlm_data.extend(instance['masked_column_token_column_ids'])
                s4 = s1 + len(mlm_data)

                mlm_data.extend(instance['masked_column_token_label_ids'])
                s5 = s1 + len(mlm_data)

                mlm_data_offsets.append([s1, s2, s3, s4, s5])
                mlm_data_sequences.extend(mlm_data)
        except Exception:
            # Log the failing example and the full stack trace, then continue
            # with the remaining examples.
            print('*' * 50 + 'Exception' + '*' * 50, file=sys.stderr)
            print(example.serialize(), file=sys.stderr)
            print('*' * 50 + 'Stack Trace' + '*' * 50, file=sys.stderr)
            traceback.print_exc(file=sys.stderr)

            sys.stderr.flush()

    _save_shard()

    if f_dbg is not None:
        f_dbg.close()
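
To make the on-disk layout concrete, below is a minimal sketch (not part of the repository) of how a consumer could read one instance back from a generated shard. `load_instance` is a hypothetical helper; it simply follows the two offset tables written by `_save_shard`: each `row_data_offsets` entry stores `[row_num, column_num, start, end]` into `row_data_sequences`, and each `mlm_data_offsets` entry stores the boundaries `[s1..s5]` of the four concatenated MLM arrays.

import h5py
import numpy as np


def load_instance(shard_path, idx):
    """Recover the idx-th pretraining instance from one epoch shard (sketch)."""
    with h5py.File(str(shard_path), 'r') as f:
        # Table payload: [row_num, column_num, start, end] into the flat token buffer.
        row_num, column_num, start, end = (int(v) for v in f['row_data_offsets'][idx])
        table_tokens = np.asarray(f['row_data_sequences'][start:end])

        # MLM payload: boundaries of the four concatenated arrays.
        s1, s2, s3, s4, s5 = (int(v) for v in f['mlm_data_offsets'][idx])
        mlm = np.asarray(f['mlm_data_sequences'][s1:s5])

    return {
        'table_size': (row_num, column_num),
        'table_tokens': table_tokens,
        # The four MLM segments, in the order they were concatenated above.
        'masked_context_token_positions': mlm[:s2 - s1],
        'masked_context_token_label_ids': mlm[s2 - s1:s3 - s1],
        'masked_column_token_column_ids': mlm[s3 - s1:s4 - s1],
        'masked_column_token_label_ids': mlm[s4 - s1:s5 - s1],
    }

Decoding the token ids inside `table_tokens` would additionally require the layout produced by `serialize_row_data` (defined elsewhere in the repository), which this sketch does not reproduce.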