in utils/generate_vanilla_tabert_training_data.py [0:0]
def generate_for_epoch(table_db: TableDatabase,
                       indices: List[int],
                       epoch_file: Path,
                       input_formatter: TableBertBertInputFormatter,
                       args: Namespace) -> None:
    """Generate pre-training instances for one epoch and save them to ``epoch_file``.

    For every index in ``indices`` the example is fetched from ``table_db`` and
    turned into instances by
    ``input_formatter.get_pretraining_instances_from_example``. The token ids
    and masked-LM data of all instances are concatenated into flat buffers;
    per-instance ``[start, end]`` pairs recorded in the ``*_offsets`` buffers
    locate each instance inside the flat buffers. Everything is written to a
    single HDF5 file at ``epoch_file`` once all examples have been processed.

    When ``args.is_master`` is truthy, instances are additionally sampled
    (``random() <= 0.05``) into a side file next to ``epoch_file`` with the
    suffix ``.sample.json``, one JSON document per line.

    An example whose processing raises an ``Exception`` is reported on stderr
    (serialized example plus stack trace) and the remainder of that example is
    skipped; instances of it that were already stored are kept.
    ``KeyboardInterrupt`` and ``SystemExit`` are not intercepted.

    Args:
        table_db: Indexable collection of examples.
        indices: Indices into ``table_db`` to process, in order.
        epoch_file: Destination path of the HDF5 file.
        input_formatter: Object that creates and prunes the instances.
        args: Parsed options; only ``args.is_master`` is read here.

    NOTE(review): ``sample_context`` is not defined in this function; it is
    presumably a module-level callable -- confirm against the rest of the file.
    """
    debug_file = epoch_file.with_suffix('.sample.json') if args.is_master else None

    # Flat buffers shared with ``_save_shard`` through the closure.
    sequences = []
    segment_a_lengths = []
    sequence_offsets = []
    masked_lm_positions = []
    masked_lm_label_ids = []
    masked_lm_offsets = []

    def _save_shard():
        """Write every buffer to ``epoch_file`` as an HDF5 dataset, then empty the buffers."""
        data = {
            'sequences': np.uint16(sequences),
            'segment_a_lengths': np.uint16(segment_a_lengths),
            'sequence_offsets': np.uint64(sequence_offsets),
            'masked_lm_positions': np.uint16(masked_lm_positions),
            'masked_lm_label_ids': np.uint16(masked_lm_label_ids),
            'masked_lm_offsets': np.uint64(masked_lm_offsets)
        }

        with h5py.File(str(epoch_file), 'w') as f:
            for key, val in data.items():
                f.create_dataset(key, data=val)

        # Clear in place so the closure keeps referring to the same list objects.
        for buffer in (sequences, segment_a_lengths, sequence_offsets,
                       masked_lm_positions, masked_lm_label_ids, masked_lm_offsets):
            buffer.clear()

    f_dbg = open(debug_file, 'w') if debug_file else None
    try:
        for example_idx in tqdm(indices, desc=f"Generating dataset {epoch_file}", file=sys.stdout):
            example = table_db[example_idx]
            try:
                instances = input_formatter.get_pretraining_instances_from_example(example, sample_context)

                for instance in instances:
                    if f_dbg is not None and random() <= 0.05:
                        # The file is opened in text mode, which already translates
                        # '\n' to the platform line ending; writing os.linesep here
                        # would yield '\r\r\n' on Windows.
                        f_dbg.write(json.dumps(instance) + '\n')

                    input_formatter.remove_unecessary_instance_entries(instance)

                    # Read every field before touching any buffer: if a field is
                    # missing, the flat arrays and their offset tables must not be
                    # left out of sync by a half-appended instance.
                    token_ids = instance['token_ids']
                    segment_a_length = instance['segment_a_length']
                    lm_positions = instance['masked_lm_positions']
                    lm_label_ids = instance['masked_lm_label_ids']
                    sequence_len = len(token_ids)
                    lm_mask_len = len(lm_positions)

                    seq_start = len(sequences)
                    sequences.extend(token_ids)
                    segment_a_lengths.append(segment_a_length)
                    sequence_offsets.append([seq_start, seq_start + sequence_len])

                    lm_start = len(masked_lm_positions)
                    masked_lm_positions.extend(lm_positions)
                    masked_lm_label_ids.extend(lm_label_ids)
                    masked_lm_offsets.append([lm_start, lm_start + lm_mask_len])
            except Exception:
                # Best effort: report the offending example and move on to the
                # next one. ``Exception`` (rather than a bare ``except``) lets
                # KeyboardInterrupt / SystemExit stop the run.
                print('*' * 50 + 'Exception' + '*' * 50, file=sys.stderr)
                print(example.serialize(), file=sys.stderr)
                print('*' * 50 + 'Stack Trace' + '*' * 50, file=sys.stderr)
                traceback.print_exc(file=sys.stderr)
                sys.stderr.flush()
    finally:
        # Always release the debug sample file, even if the loop is aborted.
        if f_dbg is not None:
            f_dbg.close()

    _save_shard()