def _split_generators()

in tensorflow_datasets/text/c4.py [0:0]


  def _split_generators(self, dl_manager, pipeline):
    """Downloads the required inputs and defines the dataset splits.

    Args:
      dl_manager: Download manager used to fetch and extract the remote
        resources and to locate manually downloaded files (`manual_dir`).
      pipeline: Beam pipeline, forwarded to `self._get_page_content`.

    Returns:
      A list of `tfds.core.SplitGenerator`. When the builder config has exactly
      one language, a train and a validation split are returned. Otherwise, for
      every configured language plus `c4_utils.UNKNOWN_LANGUAGE`, a `<lang>`
      and a `<lang>-validation` split are returned.

    Raises:
      AssertionError: If the WebText-like config is selected and the
        OpenWebText URLs archive is not present in `dl_manager.manual_dir`.
    """
    dl_manager.download_checksums(_CHECKSUMS_URL)

    # We will automatically download the first default CC version, but others
    # need to be manually downloaded.
    # De-duplicate the requested versions and sort them so the URL list is
    # built in the same order on every run (iteration order of a set of strings
    # is not stable across processes). Assumes the versions are mutually
    # comparable, e.g. all strings.
    cc_versions = sorted(set(self.builder_config.cc_versions))
    files_to_download = {
        "wet_path_urls": [
            _WET_PATH_URL.format(cc_version=cc_version)
            for cc_version in cc_versions
        ],
    }
    if self.builder_config.badwords_filter:
      badwords_urls = {
          lang: _BADWORDS_URL.format(lang=lang)
          for lang in _BADWORDS_LANGS
          if lang != "en"
      }
      # Use older "en" file for reproducibility of the original C4.
      badwords_urls["en"] = _EN_BADWORDS_URL
      files_to_download["badwords"] = badwords_urls
    if self.builder_config.realnewslike:
      files_to_download["realnews_domains"] = _REALNEWS_DOMAINS_URL
    file_paths = dl_manager.download_and_extract(files_to_download)

    if self.builder_config.webtextlike:
      # This archive is not downloaded automatically; it must already be in the
      # manual directory.
      owt_path = os.path.join(dl_manager.manual_dir, _OPENWEBTEXT_URLS_ZIP)
      if not tf.io.gfile.exists(owt_path):
        raise AssertionError(
            "For the WebText-like config, you must manually download the "
            f"following file from {_OPENWEBTEXT_URLS_URL} and place it in "
            f"{dl_manager.manual_dir}: {_OPENWEBTEXT_URLS_ZIP}")
      file_paths["openwebtext_urls_zip"] = dl_manager.extract(owt_path)

    # Convert every path-like leaf of the (nested) structure to a plain string.
    file_paths = tf.nest.map_structure(os.fspath, file_paths)

    page_content_pcollection = self._get_page_content(pipeline, file_paths,
                                                      dl_manager)

    # Build the two split filters once and share them between all splits,
    # instead of re-creating them for every element that gets filtered.
    train_filter_fn = c4_utils.get_hashed_url_filter_fn(
        predicate_fn=lambda x: x % 1000 != 0)  # 99.9%
    validation_filter_fn = c4_utils.get_hashed_url_filter_fn(
        predicate_fn=lambda x: x % 1000 == 0)  # 00.1%

    def _filter(url_and_page, lang, hashed_url_filter_fn):
      """Keeps pages tagged with `lang` that also pass the given split filter."""
      _, page = url_and_page
      return page["language"] == lang and hashed_url_filter_fn(url_and_page)

    def _make_split(name, split, split_filter_fn):
      """Builds one `SplitGenerator` over the shared page-content collection."""
      return tfds.core.SplitGenerator(
          name=name,
          gen_kwargs=dict(
              split=split,
              page_content=page_content_pcollection,
              split_filter_fn=split_filter_fn,
          ),
      )

    if len(self.builder_config.languages) == 1:
      # Single-language version.
      return [
          _make_split(tfds.Split.TRAIN, "train", train_filter_fn),
          _make_split(tfds.Split.VALIDATION, "validation",
                      validation_filter_fn),
      ]

    # Multi-language version: one train/validation pair per language, plus a
    # pair for pages tagged with the unknown-language marker.
    splits = []
    for lang in [*self.builder_config.languages, c4_utils.UNKNOWN_LANGUAGE]:
      # `lang` is bound through `functools.partial` (not captured from the loop
      # variable), so each split keeps its own language.
      splits.append(
          _make_split(
              lang, lang,
              functools.partial(
                  _filter, lang=lang, hashed_url_filter_fn=train_filter_fn)))
      splits.append(
          _make_split(
              f"{lang}-validation", f"{lang}-validation",
              functools.partial(
                  _filter, lang=lang,
                  hashed_url_filter_fn=validation_filter_fn)))
    return splits