obelics/processors/web_document_extractor.py
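"""Utilities to extract web documents (interleaved texts and images) from raw HTML pages, and to download and attach the corresponding images."""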
import glob
import json
import logging
import math
import os
import tarfile
from copy import deepcopy
import git
from datasets import Dataset, Image, Sequence, Value, concatenate_datasets, load_from_disk
from pathos.multiprocessing import ProcessingPool as Pool
from tqdm import tqdm
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def write_file(path_file, to_write):
    """Overwrite `path_file` with `to_write`."""
    # Mode "w" already truncates the file, so no explicit truncate is needed
    with open(path_file, "w") as f:
        f.write(to_write)
def html_to_web_documents(
dataset,
dom_tree_simplificator,
pre_extraction_simplificator,
num_proc,
html_column_name="html",
url_column_name="url",
):
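    """Simplify each HTML page and extract an interleaved sequence of texts and image URLs, stored in the aligned columns `texts`, `images`, `metadata`, and `general_metadata`."""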
def func_html_to_web_documents(example):
html_str = example[html_column_name]
page_url = example[url_column_name]
general_metadata = {}
        if all(
            column_name in example
            for column_name in ["url", "warc_filename", "warc_record_offset", "warc_record_length"]
        ):
general_metadata = {
"url": example["url"],
"warc_filename": example["warc_filename"],
"warc_record_offset": example["warc_record_offset"],
"warc_record_length": example["warc_record_length"],
}
try:
selectolax_tree = dom_tree_simplificator(html_str, type_return="selectolax_tree")
list_nodes = pre_extraction_simplificator(selectolax_tree, page_url=page_url)
        except Exception:
            logger.exception("Failed to extract a web document from %s", page_url)
            example["texts"] = []
            example["images"] = []
            example["metadata"] = json.dumps([])
            example["general_metadata"] = json.dumps(general_metadata)
            return example
texts = []
images = []
metadata = []
for node in list_nodes:
if node.tag == "-text":
texts.append(node.text)
images.append("")
metadata.append(None)
elif node.tag == "img":
texts.append(None)
images.append(node.media_info["src"])
metadata.append(node.media_info)
example["texts"] = texts
example["images"] = images
example["metadata"] = json.dumps(metadata)
example["general_metadata"] = json.dumps(general_metadata)
return example
logger.info("Starting extracting the documents")
dataset = dataset.map(func_html_to_web_documents, num_proc=num_proc, remove_columns=dataset.column_names)
logger.info("Finished extracting the documents")
return dataset
def get_image_urls(dataset, num_proc, path_save_file_image_urls):
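    """Collect the deduplicated image URLs of the dataset and write them, one per line, to `path_save_file_image_urls`."""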
def func_get_image_urls(example):
example["urls"] = [el for el in example["images"] if el]
return example
logger.info("Starting getting the urls of all images")
image_urls = dataset.map(func_get_image_urls, remove_columns=dataset.column_names, num_proc=num_proc)
image_urls = [sub_el for el in image_urls["urls"] for sub_el in el if sub_el]
image_urls = list(set(image_urls))
write_file(path_file=path_save_file_image_urls, to_write="\n".join(image_urls))
logger.info("Finished getting the urls of all images")
def download_images(
path_save_file_image_urls,
path_save_dir_downloaded_images,
number_sample_per_shard,
image_size,
resize_mode,
num_proc,
thread_count,
):
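    """Download the images listed in `path_save_file_image_urls` by shelling out to the img2dataset CLI, saving them as webdataset tar shards."""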
    # Before calling this function, set up a DNS resolver
    # https://github.com/rom1504/img2dataset#setting-up-a-bind9-resolver
logger.info("Starting downloading the images")
os.system(
"img2dataset"
f" --url_list={path_save_file_image_urls} --output_folder={path_save_dir_downloaded_images}"
f" --processes_count={num_proc} --thread_count={thread_count}"
f" --number_sample_per_shard={number_sample_per_shard} --image_size={image_size}"
f" --resize_mode={resize_mode} --output_format=webdataset"
)
logger.info("Finished downloading the images")
def create_dataset_images_from_tar(
tar_paths,
path_save_dir_tmp_datasets_images,
num_proc,
path_save_file_map_url_idx,
path_save_dir_dataset_images,
):
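    """Build a `Dataset` of (url, image bytes) pairs from the downloaded tar shards, and save the url -> index mapping to `path_save_file_map_url_idx`."""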
def process_one_tar(args):
(tar_path, idx_tar) = args
with tarfile.open(tar_path) as tar_file:
tar_members = tar_file.getmembers()
name_to_url = {}
name_to_img = {}
url_to_img = {}
            for tar_member in tar_members:
                if tar_member.name.endswith(".jpg"):
                    name = tar_member.name[: -len(".jpg")]
                    with tar_file.extractfile(tar_member) as tar_member_file:
                        img = tar_member_file.read()
                    name_to_img[name] = img
                elif tar_member.name.endswith(".json"):
                    name = tar_member.name[: -len(".json")]
                    with tar_file.extractfile(tar_member) as tar_member_file:
                        json_val = json.loads(tar_member_file.read())
                    status = json_val["status"]
                    url = json_val["url"]
                    if status == "success":  # Should always be the case with the webdataset format, unlike parquet
                        name_to_url[name] = url
for name in name_to_url:
url_to_img[name_to_url[name]] = name_to_img[name]
new_urls_indexed = list(url_to_img.keys())
new_datasets_images = Dataset.from_dict(
{"url": list(url_to_img.keys()), "image": list(url_to_img.values())}
)
            # We need to save the new dataset and then reload it, since `from_dict` stores the dataset
            # in RAM and does not use the disk space
new_datasets_images.save_to_disk(os.path.join(path_save_dir_tmp_datasets_images, str(idx_tar)))
return new_urls_indexed
logger.info("Starting creating the dataset of all images")
args_pool = [(tar_path, idx_tar) for idx_tar, tar_path in enumerate(tar_paths)]
pool = Pool(num_proc)
urls_indexed = pool.map(process_one_tar, args_pool)
urls_indexed = [sub_el for el in urls_indexed for sub_el in el]
map_url_idx = {url: idx for idx, url in enumerate(urls_indexed)}
with open(path_save_file_map_url_idx, "w") as f:
json.dump(map_url_idx, f)
datasets_images = [
load_from_disk(os.path.join(path_save_dir_tmp_datasets_images, str(idx_tar)))
for idx_tar in range(len(tar_paths))
]
dataset_images = concatenate_datasets(datasets_images)
dataset_images.save_to_disk(path_save_dir_dataset_images)
logger.info("Finished creating the dataset of all images")
return dataset_images
def create_dataset_images(
path_save_dir_downloaded_images,
path_save_dir_tmp_datasets_images,
num_proc,
path_save_file_map_url_idx,
path_save_dir_dataset_images,
):
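    """Find the tar shards downloaded by img2dataset in `path_save_dir_downloaded_images` and build the image dataset from them."""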
tar_paths = glob.glob(os.path.join(path_save_dir_downloaded_images, "*.tar"))
dataset_images = create_dataset_images_from_tar(
tar_paths=tar_paths,
path_save_dir_tmp_datasets_images=path_save_dir_tmp_datasets_images,
num_proc=num_proc,
path_save_file_map_url_idx=path_save_file_map_url_idx,
path_save_dir_dataset_images=path_save_dir_dataset_images,
)
return dataset_images
def urls_to_images(dataset, dataset_images, map_url_idx, num_proc, some_urls_are_already_retrieved=False):
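    """Replace the URLs in the `images` column with the corresponding downloaded images, keep the original URLs in `images_urls`, and record per document how many images were found and not found."""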
if some_urls_are_already_retrieved:
if "images_urls" not in dataset.features or "images" not in dataset.features:
raise ValueError(
"If some urls are already retrieved, the dataset must contain the features 'images_urls' and 'images'"
)
def retrieve_image(url):
if url not in map_url_idx:
return None
image = {"path": None, "bytes": dataset_images[map_url_idx[url]]["image"]}
return image
def func_urls_to_images_urls_in_images_col(example):
example["images_urls"] = deepcopy(example["images"])
num_urls = sum([(url is not None and url != "") for url in example["images_urls"]])
example["images"] = [retrieve_image(url) if url else None for url in example["images"]]
num_found = sum([img is not None for img in example["images"]])
num_not_found = num_urls - num_found
example["num_found"] = num_found
example["num_not_found"] = num_not_found
return example
def func_urls_to_images_urls_in_images_urls_col(example):
num_urls = sum([(url is not None and url != "") for url in example["images_urls"]])
example["images"] = [
img if img is not None else retrieve_image(url) if url else None
for img, url in zip(example["images"], example["images_urls"])
]
num_found = sum([img is not None for img in example["images"]])
num_not_found = num_urls - num_found
example["num_found"] = num_found
example["num_not_found"] = num_not_found
return example
func_urls_to_images = (
func_urls_to_images_urls_in_images_urls_col
if some_urls_are_already_retrieved
else func_urls_to_images_urls_in_images_col
)
logger.info("Starting replacing urls by images")
new_features = deepcopy(dataset.features)
new_features["images"] = Sequence(Image())
new_features["images_urls"] = Sequence(Value("string"))
new_features["num_found"] = Value("int32")
new_features["num_not_found"] = Value("int32")
dataset = dataset.map(
func_urls_to_images,
features=new_features,
num_proc=num_proc,
load_from_cache_file=False,
)
logger.info("Finished replacing urls by images")
return dataset
def save_split_sharded_already_splitted_dataset(dataset, path_save_dir_sharded_dataset, shard_size):
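    """Shard and save a dataset that already has "train" and "valid" splits."""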
def save_split_ds(split_dataset, split_name):
num_shards = math.ceil(len(split_dataset) / shard_size)
for idx in tqdm(range(num_shards)):
shard = split_dataset.shard(num_shards=num_shards, index=idx, contiguous=True)
shard.save_to_disk(os.path.join(path_save_dir_sharded_dataset, split_name, f"shard_{idx}"))
os.makedirs(path_save_dir_sharded_dataset, exist_ok=True)
    write_file(
        path_file=os.path.join(path_save_dir_sharded_dataset, "dataset_dict.json"),
        to_write='{"splits": ["train", "valid"]}',
    )
os.makedirs(os.path.join(path_save_dir_sharded_dataset, "train"), exist_ok=True)
os.makedirs(os.path.join(path_save_dir_sharded_dataset, "valid"), exist_ok=True)
logger.info("Starting sharding the dataset")
train_dataset = dataset["train"]
valid_dataset = dataset["valid"]
save_split_ds(train_dataset, "train")
save_split_ds(valid_dataset, "valid")
logger.info("Finished sharding the dataset")
def save_split_sharded_dataset(dataset, path_save_dir_sharded_dataset, shard_size):
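    """Shard and save a dataset without existing splits; the first two shards become the "valid" split and the rest the "train" split."""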
os.makedirs(path_save_dir_sharded_dataset, exist_ok=True)
    write_file(
        path_file=os.path.join(path_save_dir_sharded_dataset, "dataset_dict.json"),
        to_write='{"splits": ["train", "valid"]}',
    )
os.makedirs(os.path.join(path_save_dir_sharded_dataset, "train"), exist_ok=True)
os.makedirs(os.path.join(path_save_dir_sharded_dataset, "valid"), exist_ok=True)
logger.info("Starting sharding the dataset")
num_shards = math.ceil(len(dataset) / shard_size)
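    # The first two shards go to the validation split, the remaining ones to the train split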
for idx in tqdm(range(num_shards)):
shard = dataset.shard(num_shards=num_shards, index=idx, contiguous=True)
if idx < 2:
shard.save_to_disk(os.path.join(path_save_dir_sharded_dataset, "valid", f"shard_{idx}"))
else:
shard.save_to_disk(os.path.join(path_save_dir_sharded_dataset, "train", f"shard_{idx}"))
logger.info("Finished sharding the dataset")
class CommonCrawlWebDocumentExtractor:
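    """End-to-end pipeline turning a dataset of raw HTML pages into a web-document dataset.

    The methods are meant to be called in order; a minimal usage sketch (constructor arguments are illustrative):

        extractor = CommonCrawlWebDocumentExtractor(html_dataset=html_dataset, ...)
        extractor.html_to_web_documents()
        extractor.get_image_urls()
        extractor.download_images()
        extractor.create_dataset_images()
        extractor.urls_to_images()
        extractor.save_dataset()
        extractor.save_commit_hash()
        extractor.save_split_sharded_dataset()
    """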
def __init__(
self,
html_dataset,
dom_tree_simplificator,
pre_extraction_simplificator,
path_save_dir_dataset,
num_proc,
path_save_file_image_urls,
path_save_dir_downloaded_images,
thread_count,
number_sample_per_shard,
image_size,
resize_mode,
path_save_dir_tmp_datasets_images,
path_save_dir_dataset_images,
path_save_file_map_url_idx,
num_proc_urls_to_images,
path_save_dir_sharded_dataset,
shard_size,
):
self.dataset = html_dataset
self.dom_tree_simplificator = dom_tree_simplificator
self.pre_extraction_simplificator = pre_extraction_simplificator
self.path_save_dir_dataset = path_save_dir_dataset
self.num_proc = num_proc
self.path_save_file_image_urls = path_save_file_image_urls
self.path_save_dir_downloaded_images = path_save_dir_downloaded_images
self.thread_count = thread_count
self.number_sample_per_shard = number_sample_per_shard
self.image_size = image_size
self.resize_mode = resize_mode
self.path_save_dir_tmp_datasets_images = path_save_dir_tmp_datasets_images
self.path_save_dir_dataset_images = path_save_dir_dataset_images
self.path_save_file_map_url_idx = path_save_file_map_url_idx
self.num_proc_urls_to_images = num_proc_urls_to_images
self.path_save_dir_sharded_dataset = path_save_dir_sharded_dataset
self.shard_size = shard_size
def html_to_web_documents(self):
self.dataset = html_to_web_documents(
dataset=self.dataset,
dom_tree_simplificator=self.dom_tree_simplificator,
pre_extraction_simplificator=self.pre_extraction_simplificator,
num_proc=self.num_proc,
)
def get_image_urls(self):
get_image_urls(
dataset=self.dataset, num_proc=self.num_proc, path_save_file_image_urls=self.path_save_file_image_urls
)
def download_images(self):
download_images(
path_save_file_image_urls=self.path_save_file_image_urls,
path_save_dir_downloaded_images=self.path_save_dir_downloaded_images,
number_sample_per_shard=self.number_sample_per_shard,
image_size=self.image_size,
resize_mode=self.resize_mode,
num_proc=self.num_proc,
thread_count=self.thread_count,
)
def create_dataset_images(self):
self.dataset_images = create_dataset_images(
path_save_dir_downloaded_images=self.path_save_dir_downloaded_images,
path_save_dir_tmp_datasets_images=self.path_save_dir_tmp_datasets_images,
num_proc=self.num_proc,
path_save_file_map_url_idx=self.path_save_file_map_url_idx,
path_save_dir_dataset_images=self.path_save_dir_dataset_images,
)
    def urls_to_images(self, reload_files=False):
        with open(self.path_save_file_map_url_idx) as f:
            self.map_url_idx = json.load(f)
        # Reloading is useful when this method is called independently of the previous
        # ones, in which case the datasets must be loaded from disk first
        if reload_files:
            logger.info("Starting to reload variables for the step urls_to_images")
            self.dataset = load_from_disk(self.path_save_dir_dataset)
            self.dataset_images = load_from_disk(self.path_save_dir_dataset_images)
            logger.info("Finished reloading variables for the step urls_to_images")
        elif not hasattr(self, "dataset") or not hasattr(self, "dataset_images"):
            raise AttributeError(
                "Set `reload_files=True` if you're calling this method alone to define the missing variables"
            )
self.dataset = urls_to_images(
dataset=self.dataset,
dataset_images=self.dataset_images,
map_url_idx=self.map_url_idx,
num_proc=self.num_proc_urls_to_images,
)
def save_dataset(self):
logger.info("Starting saving the dataset")
self.dataset.save_to_disk(self.path_save_dir_dataset, num_proc=self.num_proc)
logger.info("Finished saving the dataset")
def save_commit_hash(self):
logger.info("Starting writing the commit hash")
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha
write_file(os.path.join(self.path_save_dir_dataset, "commit_hash.txt"), sha)
logger.info("Finished writing the commit hash")
    def save_split_sharded_dataset(self, reload_files=False):
        # Reloading is useful when this method is called independently of the previous
        # ones, in which case the dataset must be loaded from disk first
        if reload_files:
            self.dataset = load_from_disk(self.path_save_dir_dataset)
        elif not hasattr(self, "dataset"):
            raise AttributeError(
                "Set `reload_files=True` if you're calling this method alone to define the missing variables"
            )
save_split_sharded_dataset(
dataset=self.dataset,
path_save_dir_sharded_dataset=self.path_save_dir_sharded_dataset,
shard_size=self.shard_size,
)