in tensorflow_datasets/text/c4.py [0:0]
def _split_generators(self, dl_manager, pipeline):
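  """Downloads the source data and returns the dataset's SplitGenerators."""
  # Fetch and register the expected checksums for the files downloaded below.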
  dl_manager.download_checksums(_CHECKSUMS_URL)
  # We will automatically download the first default CC version, but others
  # need to be manually downloaded.
  cc_versions = set(self.builder_config.cc_versions)
  files_to_download = {}
  files_to_download["wet_path_urls"] = [
      _WET_PATH_URL.format(cc_version=cc_version)
      for cc_version in cc_versions
  ]
  if self.builder_config.badwords_filter:
    files_to_download["badwords"] = {
        lang: _BADWORDS_URL.format(lang=lang)
        for lang in _BADWORDS_LANGS
        if lang != "en"
    }
    # Use older "en" file for reproducibility of the original C4.
    files_to_download["badwords"]["en"] = _EN_BADWORDS_URL
  if self.builder_config.realnewslike:
    files_to_download["realnews_domains"] = _REALNEWS_DOMAINS_URL
  file_paths = dl_manager.download_and_extract(files_to_download)
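  # `file_paths` mirrors the structure of `files_to_download`, with each URL
  # replaced by the local path of the downloaded (and extracted) file.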

  if self.builder_config.webtextlike:
    owt_path = os.path.join(dl_manager.manual_dir, _OPENWEBTEXT_URLS_ZIP)
    if not tf.io.gfile.exists(owt_path):
      raise AssertionError(
          "For the WebText-like config, you must manually download the "
          "following file from {0} and place it in {1}: {2}".format(
              _OPENWEBTEXT_URLS_URL, dl_manager.manual_dir,
              _OPENWEBTEXT_URLS_ZIP))
    file_paths["openwebtext_urls_zip"] = dl_manager.extract(owt_path)
  file_paths = tf.nest.map_structure(os.fspath, file_paths)
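
  # A Beam PCollection of (url, page) tuples, where each page carries the
  # extracted text and metadata such as the detected language.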
  page_content_pcollection = self._get_page_content(pipeline, file_paths,
                                                    dl_manager)
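
  # Split filtering helpers: language filters key off the detected page
  # language, while train/validation membership is decided deterministically
  # by hashing each URL to an integer and testing it modulo 1000.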
  def _lang_filter(url_and_page, lang):
    _, page = url_and_page
    return page["language"] == lang

  def _filter(url_and_page, lang, predicate_fn):
    return (_lang_filter(url_and_page, lang) and
            c4_utils.get_hashed_url_filter_fn(predicate_fn)(url_and_page))

  train_predicate_fn = lambda x: x % 1000 != 0  # 99.9%
  validation_predicate_fn = lambda x: x % 1000 == 0  # 0.1%

  if len(self.builder_config.languages) == 1:
    # Single-language version.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs=dict(
                split="train",
                page_content=page_content_pcollection,
                split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                    predicate_fn=train_predicate_fn)),
        ),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION,
            gen_kwargs=dict(
                split="validation",
                page_content=page_content_pcollection,
                split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                    predicate_fn=validation_predicate_fn)),
        ),
    ]
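
  # Multilingual version: emit one train and one validation split per
  # configured language, plus splits for pages whose language could not be
  # identified.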
  splits = []
  for lang in self.builder_config.languages + [c4_utils.UNKNOWN_LANGUAGE]:
    splits.extend([
        tfds.core.SplitGenerator(
            name=lang,
            gen_kwargs=dict(
                split=lang,
                page_content=page_content_pcollection,
                split_filter_fn=functools.partial(
                    _filter, lang=lang, predicate_fn=train_predicate_fn),
            )),
        tfds.core.SplitGenerator(
            name=f"{lang}-validation",
            gen_kwargs=dict(
                split=f"{lang}-validation",
                page_content=page_content_pcollection,
                split_filter_fn=functools.partial(
                    _filter, lang=lang, predicate_fn=validation_predicate_fn),
            ))
    ])
  return splits