in examples/imagenet/generate_petastorm_imagenet.py
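
# The imports below are not shown in this excerpt; this is a hedged sketch of what the function relies on.
# The module paths for ImagenetSchema and download_nouns_mapping are assumptions based on the surrounding
# example package; adjust them to the actual layout of the repository.
import glob
import os

import cv2
from pyspark.sql import SparkSession

from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import dict_to_spark_row

from examples.imagenet.schema import ImagenetSchema  # assumed location of the Unischema definition
# download_nouns_mapping() is assumed to be defined elsewhere in this module or in a sibling helper.
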
def imagenet_directory_to_petastorm_dataset(imagenet_path, output_url, spark_master=None, parquet_files_count=100,
noun_id_to_text=None):
"""Converts a directory with imagenet data into a petastorm dataset.
Expected directory format is:
>>> nXXXXXXXX/
>>> *.JPEG
>>> nZZZZZZZZ/
>>> *.JPEG
    :param imagenet_path: a path to the directory containing the ``n*/`` subdirectories. If you are running this
        script on a Spark cluster, the directory must be mounted and accessible to the executors.
    :param output_url: the location the dataset will be written to. Must be a URL: either
        ``file://...`` or ``hdfs://...``
    :param spark_master: a master parameter used by the Spark session builder. Leave it as the default (``None``)
        to use a Spark cluster configured by the system environment. Use ``local[*]`` to run on a local box.
    :param parquet_files_count: the number of Parquet files the output dataset is coalesced into before writing.
    :param noun_id_to_text: a dictionary ``{noun_id: text}``. If ``None``, this function will download the mapping
        from the Internet.
:return: ``None``
"""
session_builder = SparkSession \
.builder \
.appName('Imagenet Dataset Creation') \
.config('spark.executor.memory', '10g') \
        .config('spark.driver.memory', '10g')  # Increase the memory if running locally with a high number of executors
if spark_master:
session_builder.master(spark_master)
spark = session_builder.getOrCreate()
sc = spark.sparkContext
# Get a list of noun_ids
noun_ids = os.listdir(imagenet_path)
if not all(noun_id.startswith('n') for noun_id in noun_ids):
        raise RuntimeError('Directory {} is expected to contain only subdirectories with names '
                           'starting with "n".'.format(imagenet_path))
if not noun_id_to_text:
noun_id_to_text = download_nouns_mapping()
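    # materialize_dataset is petastorm's context manager for dataset writes: it configures the Parquet row
    # group size used by the write below and, on exit, adds petastorm-specific metadata (the serialized
    # Unischema) to the generated dataset so it can be read back with a petastorm Reader.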
ROWGROUP_SIZE_MB = 256
with materialize_dataset(spark, output_url, ImagenetSchema, ROWGROUP_SIZE_MB):
# list of [(nXXXX, 'noun-text'), ...]
        noun_id_text_list = [(noun_id, noun_id_to_text[noun_id]) for noun_id in noun_ids]
# rdd of [(nXXXX, 'noun-text', path), ...]
        noun_id_text_image_path_rdd = sc.parallelize(noun_id_text_list, min(len(noun_ids) // 10 + 1, 10000)) \
.flatMap(lambda word_id_label: [word_id_label + (image_path,) for image_path in
glob.glob(os.path.join(imagenet_path, word_id_label[0], '*.JPEG'))])
        # rdd of dicts keyed by ImagenetSchema field names; cv2.imread loads each JPEG into a BGR ndarray
noun_id_text_image_rdd = noun_id_text_image_path_rdd \
.map(lambda id_word_image_path:
{ImagenetSchema.noun_id.name: id_word_image_path[0],
ImagenetSchema.text.name: id_word_image_path[1],
ImagenetSchema.image.name: cv2.imread(id_word_image_path[2])})
# Convert to pyspark.sql.Row
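        # dict_to_spark_row also validates each dict against ImagenetSchema and encodes the fields with the
        # codecs declared in the schema (e.g. the image ndarray is serialized by its image codec).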
sql_rows_rdd = noun_id_text_image_rdd.map(lambda r: dict_to_spark_row(ImagenetSchema, r))
# Write out the result
spark.createDataFrame(sql_rows_rdd, ImagenetSchema.as_spark_schema()) \
.coalesce(parquet_files_count) \
.write \
.mode('overwrite') \
.option('compression', 'none') \
.parquet(output_url)
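

# A minimal usage sketch, not part of the original excerpt. The path, URL and master below are placeholders;
# the real script may expose these options through a command-line interface instead.
if __name__ == '__main__':
    imagenet_directory_to_petastorm_dataset(
        imagenet_path='/data/imagenet/train',           # hypothetical directory with nXXXXXXXX/ subdirectories
        output_url='file:///tmp/imagenet_petastorm',    # hypothetical output location (file:// or hdfs:// URL)
        spark_master='local[*]')                        # run locally; pass None to use a configured cluster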