in utils/generate_vertical_tabert_training_data.py [0:0]

# Standard-library / third-party imports used below. Project-local names
# (TableDatabase, TableBertBertInputFormatter, serialize_row_data,
# sample_context) are defined or imported elsewhere in this file.
import json
import os
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from random import random
from typing import List

import h5py
import numpy as np
from tqdm import tqdm


def generate_for_epoch(table_db: TableDatabase,
                       indices: List[int],
                       epoch_file: Path,
                       input_formatter: TableBertBertInputFormatter,
                       args: Namespace):
    # Only the master process dumps a random sample of instances for debugging.
    debug_file = epoch_file.with_suffix('.sample.json') if args.is_master else None
    if debug_file:
        f_dbg = open(debug_file, 'w')
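
    # All instances in this shard are flattened into two shared 1-D buffers
    # (table token data and MLM data); a per-instance row of offsets into each
    # buffer is recorded so individual instances can be sliced back out later.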
    row_data_sequences = []
    row_data_offsets = []
    mlm_data_sequences = []
    mlm_data_offsets = []
    def _save_shard():
        # Sequence data is stored as uint16 to save space, which assumes all
        # token ids and positions fit in 16 bits; offsets are stored as uint64.
        data = {
            'row_data_sequences': np.uint16(row_data_sequences),
            'row_data_offsets': np.uint64(row_data_offsets),
            'mlm_data_sequences': np.uint16(mlm_data_sequences),
            'mlm_data_offsets': np.uint64(mlm_data_offsets),
        }

        with h5py.File(str(epoch_file), 'w') as f:
            for key, val in data.items():
                f.create_dataset(key, data=val)

        # Clear the buffers in place so the enclosing scope sees empty lists.
        del row_data_sequences[:]
        del row_data_offsets[:]
        del mlm_data_sequences[:]
        del mlm_data_offsets[:]
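
    # Resulting on-disk layout, assuming N generated instances:
    #   row_data_sequences: uint16 [total_table_tokens]
    #   row_data_offsets:   uint64 [N, 4] -> (row_num, column_num, start, end)
    #   mlm_data_sequences: uint16 [total_mlm_entries]
    #   mlm_data_offsets:   uint64 [N, 5] -> cut points s1..s5 (see loop below)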
    for example_idx in tqdm(indices, desc=f"Generating dataset {epoch_file}", file=sys.stdout):
        example = table_db[example_idx]
        try:
            instances = input_formatter.get_pretraining_instances_from_example(example, sample_context)
            for instance in instances:
                # Write ~5% of instances to the debug sample file.
                if debug_file and random() <= 0.05:
                    f_dbg.write(json.dumps(instance) + os.linesep)

                input_formatter.remove_unecessary_instance_entries(instance)
                table_data = []
                for row_inst in instance['rows']:
                    row_data = serialize_row_data(row_inst, config=input_formatter.config)
                    table_data.extend(row_data)

                row_data_offsets.append([
                    instance['table_size'][0],  # row_num
                    instance['table_size'][1],  # column_num
                    len(row_data_sequences),  # start index
                    len(row_data_sequences) + len(table_data)  # end index
                ])
                row_data_sequences.extend(table_data)
                # Pack the four MLM arrays back-to-back into the shared buffer;
                # the five cut points s1..s5 delimit the four segments.
                s1 = len(mlm_data_sequences)
                mlm_data = []
                mlm_data.extend(instance['masked_context_token_positions'])
                s2 = s1 + len(mlm_data)
                mlm_data.extend(instance['masked_context_token_label_ids'])
                s3 = s1 + len(mlm_data)
                mlm_data.extend(instance['masked_column_token_column_ids'])
                s4 = s1 + len(mlm_data)
                mlm_data.extend(instance['masked_column_token_label_ids'])
                s5 = s1 + len(mlm_data)
                mlm_data_offsets.append([s1, s2, s3, s4, s5])
                mlm_data_sequences.extend(mlm_data)
        except Exception:
            # Log the offending example and keep going with the rest of the epoch.
            print('*' * 50 + 'Exception' + '*' * 50, file=sys.stderr)
            print(example.serialize(), file=sys.stderr)
            print('*' * 50 + 'Stack Trace' + '*' * 50, file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            sys.stderr.flush()
    _save_shard()

    if debug_file:
        f_dbg.close()
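
# A minimal sketch of reading one instance back from a shard written above.
# `load_instance` is a hypothetical helper, not part of this file; it assumes
# only the dataset names and offset layouts produced by _save_shard().
def load_instance(shard_file: Path, idx: int) -> dict:
    with h5py.File(str(shard_file), 'r') as f:
        row_num, column_num, start, end = f['row_data_offsets'][idx]
        table_data = f['row_data_sequences'][start:end]

        s1, s2, s3, s4, s5 = f['mlm_data_offsets'][idx]
        mlm = f['mlm_data_sequences']
        return {
            'table_size': (int(row_num), int(column_num)),
            'table_data': table_data,  # concatenated serialize_row_data() output
            'masked_context_token_positions': mlm[s1:s2],
            'masked_context_token_label_ids': mlm[s2:s3],
            'masked_column_token_column_ids': mlm[s3:s4],
            'masked_column_token_label_ids': mlm[s4:s5],
        }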