create_only_with_pdfs/load_data.py:
import os
import re
import pandas as pd
import datasets
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import argparse

tqdm.pandas(desc="Pandas apply progress")

DATA_PATH = '/fsx/andi/pdfa_data/'
TAR_FILE_PATTERN = 'pdfa-eng-train-{:06d}.tar'
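
# This script assembles the docmatix_pdf shards: it loads pdfa-eng-train webdataset
# tar files, joins them on '__key__' with the synthetic question/answer pairs produced
# earlier, drops malformed pairs, and writes one HF dataset shard per group of tar files.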


# Heuristic filter: returns True only for non-empty strings that do not contain
# code-like structures or other generation artifacts
def is_valid_question_or_answer(text):
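    """Check that a question or answer string is usable as training text.

    Rejects empty strings and strings matching any of the patterns below
    (braces/brackets/angle brackets, IP-like numbers, 'word.word' tokens,
    blank lines, 'unanswerable', or embedded 'Qn:'/'An:' markers).
    """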
    if not text or text.strip() == "":
        return False

    # Define patterns that indicate code
    patterns = [
        r'\{.*?\}',  # Matches { ... }
        r'\[.*?\]',  # Matches [ ... ]
        r'<.*?>',  # Matches < ... >
        r'\b\d{1,3}(\.\d{1,3}){3}\b',  # Matches IP addresses
        r'\w+\.\w+',  # Matches word.word patterns
        r'\n\s*\n',  # Matches two consecutive newlines
        r'unanswerable',  # Matches 'unanswerable' regardless of case
        r'Q\d+: ',  # Contains other questions
        r'A\d+: ',  # Contains other answers
    ]
    return not any(re.search(pattern, text, re.IGNORECASE) for pattern in patterns)


# Function to process a single group of rows that share one '__key__'
def process_group(key_group):
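    """Turn all Q/A rows for one document into a single training record.

    Returns a dict with the filtered Q/A pairs under "texts" and the raw PDF
    bytes under "pdf", or None if no valid pair survives (or on error).
    """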
    try:
        key, group = key_group
        qa_pairs = []
        for _, row in group.iterrows():
            # Strip any leading "Qn: " / "An: " numbering from the generated text
            question = re.sub(r'^Q\d+: ', '', row['question'])
            answer = re.sub(r'^A\d+: ', '', row['answer'])
            if is_valid_question_or_answer(question) and is_valid_question_or_answer(answer):
                qa_pairs.append({
                    "user": question,
                    "assistant": answer,
                    "source": "PDFA key: " + str(row['__key__'])
                })
        if qa_pairs:
            return {
                "texts": qa_pairs,
                "pdf": group['pdf'].iloc[0]
            }
    except Exception as e:
        print(f"Error processing group {key}: {e}")
    return None


def process_tar_index(tar_index, step_size, question_answer_df):
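    """Load `step_size` consecutive tar files starting at `tar_index`, join them
    with the Q/A dataframe, and save the result as one dataset shard on disk."""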
    shard_nr = tar_index // step_size
    loaded_datasets = []

    for inner_idx in range(step_size):
        tar_file = os.path.join(DATA_PATH, TAR_FILE_PATTERN.format(tar_index + inner_idx))
        try:
            print(f"Loading dataset from: {tar_file}")
            hf_dataset = datasets.load_dataset('webdataset', split='train', data_files=tar_file, cache_dir="/fsx/.cache").to_pandas()
            hf_dataset['__key__'] = hf_dataset['__key__'].apply(pd.to_numeric)
            loaded_datasets.append(hf_dataset)
        except Exception as e:
            print(f"Error loading dataset from: {tar_file}")
            print(e)

    hf_dataset = pd.concat(loaded_datasets, ignore_index=True)
    print(f"Concatenated datasets with {len(hf_dataset)} samples")

    # Filter samples that are not present in question_answer_df
    hf_dataset = hf_dataset[hf_dataset['__key__'].isin(question_answer_df['__key__'].unique())]

    # Merging dataframes on '__key__' column
    merged_df = pd.merge(hf_dataset, question_answer_df, on='__key__', how='inner')

    # Using ThreadPoolExecutor for parallel processing of groups
    data_extracted = []
    max_threads = 10  # Number of threads to use
    with ThreadPoolExecutor(max_threads) as executor:
        results = list(tqdm(executor.map(process_group, merged_df.groupby('__key__')), desc='Extracting data', total=len(merged_df['__key__'].unique())))
        data_extracted.extend(results)
    data_extracted = list(filter(lambda item: item is not None, data_extracted))  # Filter out None values

    FEATURES = datasets.Features(
        {
            "pdf": datasets.Value("binary"),
            "texts": [
                {
                    "user": datasets.Value("string"),
                    "assistant": datasets.Value("string"),
                    "source": datasets.Value("string"),
                }
            ],
        }
    )
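
    # Illustrative shape of a single generated record (values are made up):
    # {
    #     "pdf": b"%PDF-1.4 ...",
    #     "texts": [
    #         {"user": "What is the title of the document?",
    #          "assistant": "Annual Report 2021",
    #          "source": "PDFA key: 1234"},
    #     ],
    # }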

    def data_generator():
        for data_dict in data_extracted:
            yield data_dict

    ds_shard = datasets.Dataset.from_generator(data_generator, features=FEATURES, writer_batch_size=100, cache_dir="/fsx/.cache")
    ds_shard.save_to_disk(f'/fsx/m4/datasets/docmatix_pdf/shard_{shard_nr}')


def load_and_concatenate_dataframes():
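    """Load the synthetic Q/A batches and return them as a single dataframe.

    Reuses the cached parquet file if it already exists; otherwise reads every
    'synthetic_dataset_batch_*.h5' file in the current directory, concatenates
    them, caches the result as parquet, and returns it.
    """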
    if os.path.exists('/fsx/andi/llm-swarm/concatenated_synthetic_dataset.parquet.gzip'):
        return pd.read_parquet('/fsx/andi/llm-swarm/concatenated_synthetic_dataset.parquet.gzip')

    # Directory where the .h5 files are stored
    directory = '.'
    # List all files in the directory
    all_files = os.listdir(directory)
    # Select the .h5 batch files and sort them
    h5_files = sorted([f for f in all_files if re.match(r'synthetic_dataset_batch_\d+\.h5$', f)])
    # Initialize an empty list to hold the dataframes
    dataframes = []

    # Load each .h5 file and append the dataframe to the list
    for file in tqdm(h5_files, desc="Loading data"):
        file_path = os.path.join(directory, file)
        df = pd.read_hdf(file_path)
        if '__key__' not in df.columns:
            raise ValueError(f"Key column not found in {file_path}")
        df['__key__'] = df['__key__'].apply(pd.to_numeric)
        dataframes.append(df)

    # Concatenate all dataframes
    concatenated_df = pd.concat(dataframes, ignore_index=True)
    concatenated_df.to_parquet('concatenated_synthetic_dataset.parquet.gzip', compression='gzip')
    return concatenated_df


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Process .h5 files and tar indices.")
    parser.add_argument('--start_index', type=int, default=0, help='The starting index for tar processing.')
    parser.add_argument('--step_size', type=int, default=1, help='The step size for tar processing.')
    args = parser.parse_args()

    question_answer_df = load_and_concatenate_dataframes()
    print(len(question_answer_df))
    process_tar_index(args.start_index, args.step_size, question_answer_df=question_answer_df)
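
# Example invocation (hypothetical; DATA_PATH, the cache dir, and the output
# directories above are cluster-specific and may need to be adapted):
#   python create_only_with_pdfs/load_data.py --start_index 0 --step_size 10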