in clean_and_create/load_data.py [0:0]
def process_tar_index(tar_index, step_size, question_answer_df):
    """Build one instruct-dataset shard from a run of tar files and save it.

    Loads ``step_size`` consecutive webdataset tar files starting at
    ``tar_index``, keeps only samples whose ``__key__`` occurs in
    ``question_answer_df``, fills the ``pdf`` column with images (taken from
    a previously saved shard when one exists on disk, otherwise produced by
    ``process_images``), joins the result with the question/answer rows,
    runs ``process_group`` on every ``__key__`` group and writes the
    resulting dataset to ``/fsx/m4/datasets/docvqa_instruct/shard_<n>``.

    Args:
        tar_index: Index of the first tar file belonging to this shard.
        step_size: Number of consecutive tar files that make up one shard.
        question_answer_df: DataFrame with a ``__key__`` column; it is
            inner-joined with the loaded samples on that column.

    Returns:
        None. The shard is written to disk as a side effect.

    Raises:
        ValueError: If none of the shard's tar files could be loaded.
    """
    shard_nr = tar_index // step_size

    # Load every tar of the shard. A tar that fails to load is reported and
    # skipped so that one bad file does not discard the whole shard.
    loaded_datasets = []
    for inner_idx in range(step_size):
        tar_file = os.path.join(DATA_PATH, TAR_FILE_PATTERN.format(tar_index + inner_idx))
        try:
            print(f"Loading dataset from: {tar_file}")
            tar_df = datasets.load_dataset(
                'webdataset', split='train', data_files=tar_file, cache_dir="/fsx/.cache"
            ).to_pandas()
            # Vectorised conversion of the sample keys to numbers.
            tar_df['__key__'] = pd.to_numeric(tar_df['__key__'])
        except Exception as e:
            print(f"Error loading dataset from: {tar_file}")
            print(e)
        else:
            loaded_datasets.append(tar_df)

    if not loaded_datasets:
        # pd.concat([]) would raise an opaque "No objects to concatenate";
        # keep the exception type but say which shard is affected.
        raise ValueError(
            f"No tar file could be loaded for shard {shard_nr} "
            f"(tar indices {tar_index}..{tar_index + step_size - 1})"
        )

    hf_dataset = pd.concat(loaded_datasets, ignore_index=True)
    print(f"Concatenated datasets with {len(hf_dataset)} samples")

    # Filter samples that are not present in q_a_df. The explicit copy makes
    # the column assignments below operate on an independent frame instead of
    # a slice (avoids pandas' SettingWithCopyWarning).
    has_questions = hf_dataset['__key__'].isin(question_answer_df['__key__'].unique())
    hf_dataset = hf_dataset[has_questions].copy()

    saved_shard_path = f"/fsx/m4/datasets/large_docvqa/shard_{shard_nr}"
    if os.path.exists(saved_shard_path):
        print('using saved data')
        df_data = datasets.load_from_disk(saved_shard_path).to_pandas()
        # The key is the second '_'-separated field of the first text's source.
        df_data["__key__"] = pd.to_numeric(
            df_data.texts.apply(lambda x: x[0]['source'].split('_')[1])
        )
        df_data = df_data.drop(columns=['texts'])
        # Filter out samples that failed conversion
        hf_dataset = hf_dataset[hf_dataset['__key__'].isin(df_data['__key__'].unique())]
        hf_dataset = pd.merge(hf_dataset, df_data, on='__key__', how='inner')
        # Replace the raw pdf column with the saved images.
        hf_dataset['pdf'] = hf_dataset['images']
        hf_dataset = hf_dataset.drop(columns=['images'])
        del df_data  # release the saved shard before the next big allocation
    else:
        # Decode pdf pages in place to save memory
        hf_dataset['pdf'] = hf_dataset['pdf'].progress_apply(process_images)
        # Filter out images that failed
        hf_dataset = hf_dataset[hf_dataset['pdf'].notna()]

    # Merging dataframes on '__key__' column
    merged_df = pd.merge(hf_dataset, question_answer_df, on='__key__', how='inner')

    # Using ThreadPoolExecutor for parallel processing of groups
    groups = merged_df.groupby('__key__')
    max_threads = 10  # Number of threads to use
    with ThreadPoolExecutor(max_threads) as executor:
        results = tqdm(
            executor.map(process_group, groups),
            desc='Extracting data',
            total=groups.ngroups,
        )
        # Filter out None values
        data_extracted = [item for item in results if item is not None]

    features = datasets.Features(
        {
            "images": datasets.Sequence(datasets.Image(decode=True)),
            "texts": [
                {
                    "user": datasets.Value("string"),
                    "assistant": datasets.Value("string"),
                    "source": datasets.Value("string"),
                }
            ],
        }
    )

    def data_generator():
        yield from data_extracted

    ds_shard = datasets.Dataset.from_generator(
        data_generator, features=features, writer_batch_size=100, cache_dir="/fsx/.cache"
    )
    ds_shard.save_to_disk(f'/fsx/m4/datasets/docvqa_instruct/shard_{shard_nr}')