# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is heavily inspired by https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
import io
import itertools
import json
import math
import os
import random
import re
from functools import partial
from typing import List, Optional, Union

import PIL.Image
import webdataset as wds
import yaml
from braceexpand import braceexpand
from torch.utils.data import default_collate
from torchvision import transforms
from transformers import PreTrainedTokenizer
from webdataset.tariterators import (
base_plus_ext,
tar_file_expander,
url_opener,
valid_sample,
)


person_token = ["a person", "someone", "somebody"]


def replace_person_token(t):
    """Replace <person> placeholders in a caption; used for CC12M."""
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        t = t.replace("<person>", f" {random.choice(person_token)} ", 1)
    return t
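# Example (the single-token replacement is random; one possible output shown):
#   replace_person_token("<person> and <person> riding a bike with <person>")
#   -> " people  riding a bike with  someone "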
def filter_keys(key_set):
    """Return a function that keeps only the entries of a sample dict whose keys are in key_set."""

    def _f(dictionary):
        return {k: v for k, v in dictionary.items() if k in key_set}

    return _f
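# Example:
#   filter_keys({"image", "input_ids"})({"image": 1, "input_ids": 2, "json": 3})
#   -> {"image": 1, "input_ids": 2}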
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Group key/value pairs from a tar stream into samples, without throwing on duplicate suffixes.

    :param keys: function that splits a file name into a sample key and a suffix (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if given, only collect these suffixes
    :param handler: unused; kept for signature compatibility with webdataset
    """
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
prefix, suffix = keys(fname)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
        # FIXME: the webdataset version throws if the suffix is already in current_sample, but
        # that can happen in the current LAION-400M dataset if one tar ends with the same prefix
        # as the next begins. Rare, but possible, since prefixes aren't unique across tar files
        # in that dataset.
if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
streams = url_opener(src, handler=handler)
files = tar_file_expander(streams, handler=handler)
samples = group_by_keys_nothrow(files, handler=handler)
return samples
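# Used below as a drop-in replacement for wds.tarfile_to_samples(...), e.g.:
#   pipeline = [wds.ResampledShards(urls), tarfile_to_samples_nothrow, wds.shuffle(1000), ...]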
def get_orig_size(json_dict):
    return (int(json_dict.get("original_width", 0.0)), int(json_dict.get("original_height", 0.0)))
def get_aesthetic_score(json_dict):
    if "aesthetic" in json_dict:
        a = json_dict["aesthetic"]
    elif "AESTHETIC_SCORE" in json_dict:
        a = json_dict["AESTHETIC_SCORE"]
    elif "aesthetic_score_laion_v2" in json_dict:
        a = json_dict["aesthetic_score_laion_v2"]
    elif "stability_metadata" in json_dict and "aes_scorelv2" in json_dict["stability_metadata"]:
        a = json_dict["stability_metadata"]["aes_scorelv2"]
    else:
        a = 0.0
    return float(a)
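# Example (hypothetical LAION-style metadata):
#   meta = {"original_width": 512, "original_height": 384, "AESTHETIC_SCORE": 5.2}
#   get_orig_size(meta)       -> (512, 384)
#   get_aesthetic_score(meta) -> 5.2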
class ImageNetTransform:
def __init__(self, resolution, center_crop=True, random_flip=False):
self.train_transform = transforms.Compose(
[
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
(transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
]
)
self.eval_transform = transforms.Compose(
[
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(resolution),
transforms.ToTensor(),
]
)
def image_transform(example, resolution=256):
image = example["image"]
image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
    # random crop to the target resolution, recording the crop coordinates for micro-conditioning
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)
image = transforms.ToTensor()(image)
example["image"] = image
example["crop_coords"] = (c_top, c_left)
return example
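# A minimal sketch (crop position is random):
#   out = image_transform({"image": PIL.Image.new("RGB", (640, 480))}, resolution=256)
#   out["image"].shape  -> torch.Size([3, 256, 256])
#   out["crop_coords"]  -> (top, left) of the random crop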
class ClassificationDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
eval_shards_path_or_url: Union[str, List[str]],
num_train_examples: int,
per_gpu_batch_size: int,
global_batch_size: int,
num_workers: int,
resolution: int = 256,
return_text: bool = False,
        tokenizer: Optional[PreTrainedTokenizer] = None,
max_seq_length: int = 16,
center_crop: bool = True,
random_flip: bool = False,
imagenet_class_mapping_path=None,
shuffle_buffer_size: int = 1000,
pin_memory: bool = False,
persistent_workers: bool = False,
**kwargs,
):
transform = ImageNetTransform(resolution, center_crop, random_flip)
if return_text:
if imagenet_class_mapping_path is None:
raise ValueError("imagenet_class_mapping_path must be provided when return_text is True")
with open(imagenet_class_mapping_path, "r") as f:
self.class_mapping = json.load(f)
def tokenize(imagenet_class_id):
text = self.class_mapping[str(imagenet_class_id)]
input_ids = tokenizer(
text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
).input_ids
return input_ids[0]
processing_pipeline = [
wds.rename(
image="jpg;png;jpeg;webp",
input_ids="cls",
text_raw="cls",
class_id="cls",
handler=wds.warn_and_continue,
),
wds.map(filter_keys(set(["image", "input_ids", "text_raw", "class_idx"]))),
wds.map_dict(
image=transform.train_transform,
input_ids=tokenize,
text_raw=lambda class_idx: self.class_mapping[str(class_idx)],
),
wds.to_tuple("image", "input_ids"),
]
else:
processing_pipeline = [
wds.rename(image="jpg;png;jpeg;webp", class_id="cls", handler=wds.warn_and_continue),
wds.map(filter_keys(set(["image", "class_id"]))),
wds.map_dict(image=transform.train_transform, class_id=lambda x: int(x)),
wds.to_tuple("image", "class_id"),
]
# Create train dataset and loader
pipeline = [
wds.ResampledShards(train_shards_path_or_url),
wds.tarfile_to_samples(handler=wds.ignore_and_continue),
wds.shuffle(shuffle_buffer_size),
wds.decode("pil", handler=wds.ignore_and_continue),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
]
        # Round the batch count up to a multiple of the worker count so that every
        # dataloader worker yields the same number of batches.
        effective_workers = max(num_workers, 1)  # guard against num_workers == 0 (in-process loading)
        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * effective_workers))  # per dataloader worker
        num_batches = num_worker_batches * effective_workers
        num_samples = num_batches * global_batch_size
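        # e.g. 1,281,167 train examples, global batch 256, 4 workers:
        #   num_worker_batches = ceil(1281167 / 1024) = 1252 -> num_batches = 5008, num_samples = 1,282,048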
        # Each dataloader worker iterates over this pipeline; with_epoch caps an
        # epoch at num_worker_batches batches per worker.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
self._train_dataloader = wds.WebLoader(
self._train_dataset,
batch_size=None,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
)
# add meta-data to dataloader instance for convenience
self._train_dataloader.num_batches = num_batches
self._train_dataloader.num_samples = num_samples
# Create eval dataset and loader
pipeline = [
wds.SimpleShardList(eval_shards_path_or_url),
wds.split_by_worker,
wds.tarfile_to_samples(handler=wds.ignore_and_continue),
wds.decode("pil", handler=wds.ignore_and_continue),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
]
self._eval_dataset = wds.DataPipeline(*pipeline)
self._eval_dataloader = wds.WebLoader(
self._eval_dataset,
batch_size=None,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
)
@property
def train_dataset(self):
return self._train_dataset
@property
def train_dataloader(self):
return self._train_dataloader
@property
def eval_dataset(self):
return self._eval_dataset
@property
def eval_dataloader(self):
return self._eval_dataloader
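# A minimal usage sketch (shard URLs are hypothetical):
#   dataset = ClassificationDataset(
#       train_shards_path_or_url="pipe:curl -L -s https://example.com/imagenet-train-{0000..1281}.tar",
#       eval_shards_path_or_url="pipe:curl -L -s https://example.com/imagenet-val-{00..04}.tar",
#       num_train_examples=1_281_167,
#       per_gpu_batch_size=64,
#       global_batch_size=256,
#       num_workers=4,
#   )
#   for images, class_ids in dataset.train_dataloader:
#       ...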
class WebdatasetSelect:
def __init__(
self,
min_size=256,
max_pwatermark=0.5,
min_aesthetic_score=4.9,
require_marked_as_ok_by_spawning=False,
require_marked_as_not_getty=False,
max_pnsfw=None,
):
self.min_size = min_size
self.max_pwatermark = max_pwatermark
self.min_aesthetic_score = min_aesthetic_score
self.require_marked_as_ok_by_spawning = require_marked_as_ok_by_spawning
self.require_marked_as_not_getty = require_marked_as_not_getty
self.max_pnsfw = max_pnsfw
    def __call__(self, x):
        if "json" not in x:
            return False
        try:
            x_json = json.loads(x["json"])
        except ValueError:  # json.JSONDecodeError or a bad encoding
            return False
        # For every requirement, if the necessary key(s) are missing we treat the
        # requirement as not met and reject the sample. Several alternative keys are
        # checked per requirement because different source datasets ship different metadata dicts.
# size
if "original_width" not in x_json or "original_height" not in x_json:
return False
original_width = x_json["original_width"]
original_height = x_json["original_height"]
is_less_than_min_size = original_width < self.min_size or original_height < self.min_size
if is_less_than_min_size:
return False
# watermark
if (
("pwatermark" not in x_json or x_json["pwatermark"] is None)
and "watermark_score" not in x_json
and ("stability_metadata" not in x_json or "p_watermarkdf" not in x_json["stability_metadata"])
):
return False
if "pwatermark" in x_json and x_json["pwatermark"] is not None:
is_watermarked = x_json["pwatermark"] > self.max_pwatermark
if is_watermarked:
return False
if "watermark_score" in x_json:
is_watermarked_coyo = x_json["watermark_score"] > self.max_pwatermark
if is_watermarked_coyo:
return False
if "stability_metadata" in x_json and "p_watermarkdf" in x_json["stability_metadata"]:
is_watermarked_stability_metadata = x_json["stability_metadata"]["p_watermarkdf"] > self.max_pwatermark
if is_watermarked_stability_metadata:
return False
# aesthetic
if (
"aesthetic" not in x_json
and "AESTHETIC_SCORE" not in x_json
and "aesthetic_score_laion_v2" not in x_json
and ("stability_metadata" not in x_json or "aes_scorelv2" not in x_json["stability_metadata"])
):
return False
if "aesthetic" in x_json:
is_under_min_aesthetic_threshold = x_json["aesthetic"] < self.min_aesthetic_score
if is_under_min_aesthetic_threshold:
return False
if "AESTHETIC_SCORE" in x_json:
is_under_min_aesthetic_threshold_b = x_json["AESTHETIC_SCORE"] < self.min_aesthetic_score
if is_under_min_aesthetic_threshold_b:
return False
if "aesthetic_score_laion_v2" in x_json:
is_under_min_aesthetic_threshold_coyo = x_json["aesthetic_score_laion_v2"] < self.min_aesthetic_score
if is_under_min_aesthetic_threshold_coyo:
return False
if "stability_metadata" in x_json and "aes_scorelv2" in x_json["stability_metadata"]:
is_under_min_aesthetic_threshold_stability_metadata = (
x_json["stability_metadata"]["aes_scorelv2"] < self.min_aesthetic_score
)
if is_under_min_aesthetic_threshold_stability_metadata:
return False
# spawning
if self.require_marked_as_ok_by_spawning:
if "stability_metadata" not in x_json or "is_spawning" not in x_json["stability_metadata"]:
return False
is_marked_as_not_ok_by_spawning = x_json["stability_metadata"]["is_spawning"]
if is_marked_as_not_ok_by_spawning:
return False
# getty
if self.require_marked_as_not_getty:
if "stability_metadata" not in x_json or "is_getty" not in x_json["stability_metadata"]:
return False
is_marked_as_getty = x_json["stability_metadata"]["is_getty"]
if is_marked_as_getty:
return False
# nsfw
if self.max_pnsfw is not None:
if "stability_metadata" not in x_json or "p_nsfwdf" not in x_json["stability_metadata"]:
return False
is_above_max_nsfw = x_json["stability_metadata"]["p_nsfwdf"] > self.max_pnsfw
if is_above_max_nsfw:
return False
return True
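# Example: with the defaults this keeps samples that are at least 256px on both sides,
# have pwatermark <= 0.5, and an aesthetic score >= 4.9 (a minimal sketch):
#   select_stage = wds.select(WebdatasetSelect(max_pnsfw=0.5))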
def sdxl_synthetic_dataset_map(sample):
    clip_scores = [float(x) for x in sample["clip_scores.txt"].decode("utf-8").split(",")]
    # index of the image with the highest clip score (ties go to the first)
    index_of_max = max(range(len(clip_scores)), key=clip_scores.__getitem__)
    key_of_best_clip_score_image = f"{index_of_max}.png"
    if key_of_best_clip_score_image not in sample:
        raise ValueError(
            f"{key_of_best_clip_score_image} was not found in sample. The dataset should have files <sample"
            " key>.<x>.png where <x> corresponds to an index of the clip scores in clip_scores.txt"
        )
return {
"__key__": sample["__key__"],
"__url__": sample["__url__"],
"txt": sample["txt"],
"png": sample[key_of_best_clip_score_image], # only include the image with the best clip score
        # For other datasets, we rely on the json metadata for micro-conditioning.
        # The original height and width are known because we generated the dataset with
        # SDXL at 1024x1024. A LAION aesthetic score of 5 seems like a reasonable default.
        # NOTE: we unfortunately have to serialize and encode the json so it looks like
        # it was read out of a file, since wds decoders will need to decode it. There
        # is probably some way to avoid this but it is not obvious with the wds apis.
        "json": json.dumps({"aesthetic": 5, "original_width": 1024, "original_height": 1024}).encode(),
}
def ds_clean_upscaled_map(sample):
with io.BytesIO(sample["png"]) as stream:
image = PIL.Image.open(stream)
image.load()
return {
"__key__": sample["__key__"],
"__url__": sample["__url__"],
"txt": sample["txt"],
"png": sample["png"],
"json": json.dumps({"aesthetic": 5, "original_width": image.width, "original_height": image.height}).encode(),
}
def ds_clean_map(sample):
with io.BytesIO(sample["png"]) as stream:
image = PIL.Image.open(stream)
image.load()
    # Take only the top-left image (one quadrant of the grid)
    height = image.height // 2
    width = image.width // 2
    image = image.crop((0, 0, width, height))
    image_bytes = io.BytesIO()
    image.save(image_bytes, format="PNG")  # re-encode the cropped image as PNG bytes
    image = image_bytes.getvalue()
return {
"__key__": sample["__key__"],
"__url__": sample["__url__"],
"txt": sample["txt"],
"png": image,
"json": json.dumps({"aesthetic": 5, "original_width": width, "original_height": height}).encode(),
}
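# Each of the three map functions above normalizes a special-case shard layout into the
# same {txt, png, json} sample structure that the standard processing pipeline expects.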
class Text2ImageDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
eval_shards_path_or_url: Union[str, List[str]],
tokenizer: PreTrainedTokenizer,
max_seq_length: int,
num_train_examples: int,
per_gpu_batch_size: int,
global_batch_size: int,
num_workers: int,
resolution: int = 256,
center_crop: bool = True,
random_flip: bool = False,
shuffle_buffer_size: int = 1000,
pin_memory: bool = False,
persistent_workers: bool = False,
is_pre_encoded: bool = False,
vae_checkpoint: Optional[str] = None,
text_encoder_checkpoint: Optional[str] = None,
use_filtered_dataset: bool = False,
require_marked_as_ok_by_spawning: bool = False,
require_marked_as_not_getty: bool = False,
max_pnsfw: Optional[float] = None,
max_pwatermark: Optional[float] = 0.5,
min_aesthetic_score: Optional[float] = 4.75,
min_size: Optional[int] = 256,
is_sdxl_synthetic_dataset: bool = False,
is_ds_clean_upscaled: bool = False,
is_ds_clean: bool = False,
):
if f"{train_shards_path_or_url}.yaml" in os.listdir('./configs'):
with open(f"./configs/{train_shards_path_or_url}.yaml") as f:
train_shards_path_or_url = yaml.safe_load(f)
transform = ImageNetTransform(resolution, center_crop, random_flip)
def tokenize(text):
text = replace_person_token(text)
input_ids = tokenizer(
text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
).input_ids
return input_ids[0]
if not isinstance(train_shards_path_or_url, str):
train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
# flatten list using itertools
train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
if not isinstance(eval_shards_path_or_url, str):
eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
# flatten list using itertools
eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))
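        # e.g. (hypothetical) ["s3://bucket/a-{00..01}.tar", "s3://bucket/b-00.tar"]
        #   -> ["s3://bucket/a-00.tar", "s3://bucket/a-01.tar", "s3://bucket/b-00.tar"]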
if not is_pre_encoded:
processing_pipeline = [
wds.decode("pil", handler=wds.ignore_and_continue),
wds.rename(
image="jpg;png;jpeg;webp",
input_ids="text;txt;caption",
orig_size="json",
aesthetic_score="json",
handler=wds.warn_and_continue,
),
wds.map(filter_keys(set(["image", "input_ids", "orig_size", "aesthetic_score"]))),
wds.map(partial(image_transform, resolution=resolution), handler=wds.warn_and_continue),
wds.map_dict(
input_ids=tokenize,
orig_size=get_orig_size,
aesthetic_score=get_aesthetic_score,
handler=wds.warn_and_continue,
),
]
        else:
            if vae_checkpoint is None or text_encoder_checkpoint is None:
                raise ValueError("vae_checkpoint and text_encoder_checkpoint must be provided when is_pre_encoded is True")
            # The pre-encoded shards store tensors under the checkpoint names, lowercased and with "/" replaced by "."
            vae_checkpoint = vae_checkpoint.lower().replace("/", ".")
            text_encoder_checkpoint = text_encoder_checkpoint.lower().replace("/", ".")
processing_pipeline = [
wds.decode(wds.handle_extension("pth", wds.autodecode.torch_loads), handler=wds.ignore_and_continue),
wds.rename(
image_input_ids=f"{vae_checkpoint}.pth",
encoder_hidden_states=f"{text_encoder_checkpoint}.pth",
handler=wds.warn_and_continue,
),
wds.map(filter_keys(set(["image_input_ids", "encoder_hidden_states"]))),
]
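            # e.g. (hypothetical) vae_checkpoint "openMUSE/vqgan-f16-8192-laion" becomes
            # "openmuse.vqgan-f16-8192-laion", so the shard file
            # "<key>.openmuse.vqgan-f16-8192-laion.pth" is renamed to "image_input_ids".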
if is_sdxl_synthetic_dataset:
select = wds.select(lambda sample: "clip_scores.txt" in sample)
elif use_filtered_dataset:
select = wds.select(
WebdatasetSelect(
require_marked_as_ok_by_spawning=require_marked_as_ok_by_spawning,
require_marked_as_not_getty=require_marked_as_not_getty,
max_pnsfw=max_pnsfw,
max_pwatermark=max_pwatermark,
min_aesthetic_score=min_aesthetic_score,
min_size=min_size,
)
)
        else:
            select = None

        if is_sdxl_synthetic_dataset:
            map_fn = wds.map(sdxl_synthetic_dataset_map, handler=wds.ignore_and_continue)
        elif is_ds_clean_upscaled:
            map_fn = wds.map(ds_clean_upscaled_map)
        elif is_ds_clean:
            map_fn = wds.map(ds_clean_map)
        else:
            map_fn = None
# Create train dataset and loader
pipeline = [
wds.ResampledShards(train_shards_path_or_url),
tarfile_to_samples_nothrow,
*([select] if select is not None else []),
            *([map_fn] if map_fn is not None else []),
wds.shuffle(shuffle_buffer_size),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
]
        # Round the batch count up to a multiple of the worker count so that every
        # dataloader worker yields the same number of batches.
        effective_workers = max(num_workers, 1)  # guard against num_workers == 0 (in-process loading)
        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * effective_workers))  # per dataloader worker
        num_batches = num_worker_batches * effective_workers
        num_samples = num_batches * global_batch_size
        # Each dataloader worker iterates over this pipeline; with_epoch caps an
        # epoch at num_worker_batches batches per worker.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
self._train_dataloader = wds.WebLoader(
self._train_dataset,
batch_size=None,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
)
# add meta-data to dataloader instance for convenience
self._train_dataloader.num_batches = num_batches
self._train_dataloader.num_samples = num_samples
# Create eval dataset and loader
pipeline = [
wds.SimpleShardList(eval_shards_path_or_url),
wds.split_by_worker,
wds.tarfile_to_samples(handler=wds.ignore_and_continue),
*processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),  # partial=False drops a trailing incomplete batch
]
self._eval_dataset = wds.DataPipeline(*pipeline)
self._eval_dataloader = wds.WebLoader(
self._eval_dataset,
batch_size=None,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
)
@property
def train_dataset(self):
return self._train_dataset
@property
def train_dataloader(self):
return self._train_dataloader
@property
def eval_dataset(self):
return self._eval_dataset
@property
def eval_dataloader(self):
return self._eval_dataloader
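# A minimal usage sketch (shard URLs are hypothetical):
#   from transformers import CLIPTokenizer
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   dataset = Text2ImageDataset(
#       train_shards_path_or_url="pipe:curl -L -s https://example.com/train-{0000..0999}.tar",
#       eval_shards_path_or_url="pipe:curl -L -s https://example.com/eval-{00..09}.tar",
#       tokenizer=tokenizer,
#       max_seq_length=77,
#       num_train_examples=1_000_000,
#       per_gpu_batch_size=32,
#       global_batch_size=256,
#       num_workers=4,
#   )
#   for batch in dataset.train_dataloader:
#       # batch is a collated dict: image, input_ids, orig_size, aesthetic_score, crop_coords
#       ...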