in tensorflow_datasets/vision_language/wit_kaggle/wit_kaggle.py [0:0]
def _generate_examples(self, pipeline, samples_path, image_pixels_path,
                       image_resnet_path):
  """Processes the dataset and yields examples.

  Joins sample rows with (optional) image pixel and ResNet representations,
  keyed by image_url, and emits one example per (image_url, sample) pair.

  Args:
    pipeline: the beam pipeline.
    samples_path: path to the split's sentences.
    image_pixels_path: path to the images' pixel representations.
    image_resnet_path: path to the images' ResNet embeddings.

  Returns:
    Examples.
  """
  beam = tfds.core.lazy_imports.apache_beam
  # All metric counters below are grouped under the dataset's namespace.
  counter = functools.partial(beam.metrics.Metrics.counter, _BEAM_NAMESPACE)

  def _get_csv_reader(filename):
    """Returns a tab-delimited csv reader over (possibly gzipped) filename."""
    if filename.suffix == ".gz":
      counter("gz_csv_files").inc()
      g = tf.io.gfile.GFile(filename, "rb")
      f = gzip.open(g, "rt", newline="")
    else:
      counter("normal_csv_files").inc()
      f = tf.io.gfile.GFile(filename, "r")
    return csv.reader(f, delimiter="\t")

  def _read_pixel_rows(filename):
    r"""Yields [image_url, (image_pixels, metadata_url)] pairs.

    Each input row contains image_url \t image_pixel \t metadata_url.
    """
    reader = _get_csv_reader(filename)
    for row in reader:
      counter("pixel_rows").inc()
      image_url, image_representation, metadata_url = row
      if image_url:
        yield [image_url, (image_representation, metadata_url)]
      else:
        counter("pixel_rows_no_image_url").inc()

  def _read_resnet_rows(filename):
    r"""Yields [image_url, resnet_embedding] pairs.

    Each input row contains image_url \t resnet_embedding.
    """
    reader = _get_csv_reader(filename)
    for row in reader:
      counter("resnet_rows").inc()
      image_url, image_representation = row
      if image_url:
        yield [image_url, image_representation]
      else:
        counter("resnet_rows_no_image_url").inc()

  def _read_samples_rows(folder_path):
    """Yields [image_url, sample] pairs; train and test have different fields."""
    # Sample fields can be very large, so lift the csv module's default field
    # size limit. The value must be smaller than the C long maximum value.
    # Note: this is a process-global setting, so set it once, not per file.
    csv.field_size_limit(sys.maxsize)
    for filename in tf.io.gfile.listdir(folder_path):
      file_path = folder_path / filename
      f = tf.io.gfile.GFile(file_path, "r")
      csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL)
      for row in csv_reader:
        counter("samples_rows").inc()
        # Keep only the features declared for this split's config.
        sample = {
            feature_key: row[feature_key] for feature_key in
            self.builder_config.split_specific_features.keys()
        }
        image_url = row["image_url"]
        if image_url:
          yield [image_url, sample]
        else:
          counter("samples_rows_no_image_url").inc()

  def _process_examples(el):
    """Merges co-grouped sample/pixel/resnet values into final examples."""
    sample_url, sample_fields = el
    # Each image_url can be associated with multiple samples (e.g., multiple
    # languages).
    for i, sample_info in enumerate(sample_fields["sample_info"]):
      sample_id = f"{i}_{sample_url}"
      sample = {"image_url": sample_url}
      for feature_key in self.builder_config.split_specific_features.keys():
        sample[feature_key] = sample_info[feature_key]
      # Test samples don't have gold captions.
      if "caption_title_and_reference_description" not in sample_info:
        sample["caption_title_and_reference_description"] = ""
      # We output image data only if there is exactly one distinct image
      # representation per image_url.
      # Not all of the samples in the competition have corresponding image
      # data. In case multiple different image representations are associated
      # with the same image_url, we don't know which one is correct and don't
      # output any.
      unique_pixels = set(sample_fields["image_pixels"])
      if len(unique_pixels) == 1:
        sample_image, sample_metadata = next(iter(unique_pixels))
        sample["image"] = io.BytesIO(base64.b64decode(sample_image))
        sample["metadata_url"] = sample_metadata
      else:
        if unique_pixels:
          counter("image_pixels_multiple").inc()
        else:
          counter("image_pixels_missing").inc()
        sample["image"] = io.BytesIO(base64.b64decode(_EMPTY_IMAGE_BYTES))
        sample["metadata_url"] = ""
      # Same policy for the ResNet embeddings: emit only when unambiguous.
      unique_resnet = set(sample_fields["image_resnet"])
      if len(unique_resnet) == 1:
        sample["embedding"] = [
            float(x) for x in next(iter(unique_resnet)).split(",")
        ]
      else:
        if unique_resnet:
          counter("image_resnet_multiple").inc()
        else:
          counter("image_resnet_missing").inc()
        sample["embedding"] = self.builder_config.empty_resnet_embedding
      yield sample_id, sample

  # Read embeddings and bytes representations from (possibly compressed) csv.
  image_resnet_files = [
      image_resnet_path / f for f in tf.io.gfile.listdir(image_resnet_path)
  ]
  resnet_collection = (
      pipeline
      | "Collection from resnet files" >> beam.Create(image_resnet_files)
      | "Get embeddings per image" >> beam.FlatMap(_read_resnet_rows))
  image_pixel_files = [
      image_pixels_path / f for f in tf.io.gfile.listdir(image_pixels_path)
  ]
  pixel_collection = (
      pipeline
      | "Collection from pixel files" >> beam.Create(image_pixel_files)
      | "Get pixels per image" >> beam.FlatMap(_read_pixel_rows))
  # Read samples from tsv files.
  sample_collection = (
      pipeline
      | "Collection from sample files" >> beam.Create(samples_path)
      | "Get samples" >> beam.FlatMap(_read_samples_rows))
  # Combine the features and yield examples.
  return ({
      "sample_info": sample_collection,
      "image_pixels": pixel_collection,
      "image_resnet": resnet_collection,
  }
          | "Group by image_url" >> beam.CoGroupByKey()
          | "Reshuffle" >> beam.Reshuffle()
          | "Process and yield examples" >> beam.FlatMap(_process_examples))