def _generate_examples()

in tensorflow_datasets/vision_language/wit_kaggle/wit_kaggle.py


  def _generate_examples(self, pipeline, samples_path, image_pixels_path,
                         image_resnet_path):
    """Processes the dataset and yields examples.

    Args:
      pipeline: the beam pipeline.
      samples_path: path to the split's samples.
      image_pixels_path: path to the images' pixel representations.
      image_resnet_path: path to the images' ResNet embeddings.

    Returns:
      A Beam PCollection of (key, example) tuples.
    """
    beam = tfds.core.lazy_imports.apache_beam
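    # Helper to create Beam counters under the shared namespace; used below to
    # track parsed rows and anomalies (e.g., rows without an image_url).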
    counter = functools.partial(beam.metrics.Metrics.counter, _BEAM_NAMESPACE)

    def _get_csv_reader(filename):
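      # Image representation files may be gzip-compressed (".gz" suffix) or
      # plain TSV; in both cases return a tab-delimited csv reader.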
      if filename.suffix == ".gz":
        counter("gz_csv_files").inc()
        g = tf.io.gfile.GFile(filename, "rb")
        f = gzip.open(g, "rt", newline="")
      else:
        counter("normal_csv_files").inc()
        f = tf.io.gfile.GFile(filename, "r")
      return csv.reader(f, delimiter="\t")

    def _read_pixel_rows(filename):
      r"""Contains image_url \t image_pixel \t metadata_url."""
      reader = _get_csv_reader(filename)
      for row in reader:
        counter("pixel_rows").inc()
        image_url, image_representation, metadata_url = row
        if image_url:
          yield [image_url, (image_representation, metadata_url)]
        else:
          counter("pixel_rows_no_image_url").inc()

    def _read_resnet_rows(filename):
      r"""Contains image_url \t resnet_embedding."""
      reader = _get_csv_reader(filename)
      for row in reader:
        counter("resnet_rows").inc()
        image_url, image_representation = row
        if image_url:
          yield [image_url, image_representation]
        else:
          counter("resnet_rows_no_image_url").inc()

    def _read_samples_rows(folder_path):
      """Contains samples: train and test have different fields."""
      for filename in tf.io.gfile.listdir(folder_path):
        file_path = folder_path / filename
        f = tf.io.gfile.GFile(file_path, "r")
        # Some fields are very long; raise the csv field size limit to 100 MB.
        # The value must be smaller than the C long maximum value.
        csv.field_size_limit(100 * 1024 * 1024)
        csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL)
        for row in csv_reader:
          counter("samples_rows").inc()
          sample = {
              feature_key: row[feature_key] for feature_key in
              self.builder_config.split_specific_features.keys()
          }
          image_url = row["image_url"]
          if image_url:
            yield [image_url, sample]
          else:
            counter("samples_rows_no_image_url").inc()

    def _process_examples(el):
      sample_url, sample_fields = el
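      # `sample_fields` is the CoGroupByKey result for this image_url: a dict
      # with "sample_info", "image_pixels" and "image_resnet" lists.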
      # Each image_url can be associated with multiple samples (e.g., multiple
      # languages).
      for i, sample_info in enumerate(sample_fields["sample_info"]):
        sample_id = f"{i}_{sample_url}"
        sample = {"image_url": sample_url}
        for feature_key in self.builder_config.split_specific_features.keys():
          sample[feature_key] = sample_info[feature_key]
        # Test samples don't have gold captions.
        if "caption_title_and_reference_description" not in sample_info:
          sample["caption_title_and_reference_description"] = ""

        # We output image data only if there is at least one image
        # representation per image_url.
        # Not all of the samples in the competition have corresponding image
        # data. In case multiple different image representations are associated
        # with the same image_url, we don't know which one is correct and don't
        # output any.
        if len(set(sample_fields["image_pixels"])) == 1:
          sample_image, sample_metadata = sample_fields["image_pixels"][0]
          sample["image"] = io.BytesIO(base64.b64decode(sample_image))
          sample["metadata_url"] = sample_metadata
        else:
          if len(set(sample_fields["image_pixels"])) > 1:
            counter("image_pixels_multiple").inc()
          else:
            counter("image_pixels_missing").inc()
            sample["image"] = io.BytesIO(base64.b64decode(_EMPTY_IMAGE_BYTES))
          sample["metadata_url"] = ""

        if len(set(sample_fields["image_resnet"])) == 1:
          image_resnet = [
              float(x) for x in sample_fields["image_resnet"][0].split(",")
          ]
          sample["embedding"] = image_resnet
        else:
          if len(set(sample_fields["image_resnet"])) > 1:
            counter("image_resnet_multiple").inc()
          else:
            counter("image_resnet_missing").inc()
          sample["embedding"] = self.builder_config.empty_resnet_embedding

        yield sample_id, sample

    # Read embeddings and bytes representations from (possibly compressed) csv.
    image_resnet_files = [
        image_resnet_path / f for f in tf.io.gfile.listdir(image_resnet_path)
    ]
    resnet_collection = (
        pipeline
        | "Collection from resnet files" >> beam.Create(image_resnet_files)
        | "Get embeddings per image" >> beam.FlatMap(_read_resnet_rows))

    image_pixel_files = [
        image_pixels_path / f for f in tf.io.gfile.listdir(image_pixels_path)
    ]
    pixel_collection = (
        pipeline
        | "Collection from pixel files" >> beam.Create(image_pixel_files)
        | "Get pixels per image" >> beam.FlatMap(_read_pixel_rows))

    # Read samples from tsv files.
    sample_collection = (
        pipeline
        | "Collection from sample files" >> beam.Create(samples_path)
        | "Get samples" >> beam.FlatMap(_read_samples_rows))

    # Combine the features and yield examples.
    return ({
        "sample_info": sample_collection,
        "image_pixels": pixel_collection,
        "image_resnet": resnet_collection,
    }
            | "Group by image_url" >> beam.CoGroupByKey()
            | "Reshuffle" >> beam.Reshuffle()
            | "Process and yield examples" >> beam.FlatMap(_process_examples))