def load_dataset_from_disk()

in community-artifacts/Deep-learning/Utilities/madlib_image_loader.py [0:0]


    def load_dataset_from_disk(self, root_dir, table_name, num_labels='all',
                               append=False, label_datatype='TEXT'):
        """
        Load images from disk into a greenplum database table. All the images
            should be of the same shape.
        @root_dir: Location of the dir which contains all the labels and their
            associated images. Can be relative or absolute. Each label needs to
            have its own dir and should contain only images inside its own dir.
            (Extra files in root dir will be ignored, only directories matter.)
            Label dirs are processed in sorted (lexicographic) order.
        @table_name: Name of the destination database table into which images
            will be loaded.
        @num_labels: Number of labels to process/load into the table, or 'all'
            (the default) to load every label dir. When a number is given, the
            first num_labels label dirs in sorted order are used.
        @append: If set to true, do not create a new table but append to an
            existing table.
        @label_datatype: If set, the table is created with the column 'y' set
            to the datatype specified. Default is TEXT.
        Raises ValueError if num_labels is neither 'all' nor a non-negative
            integer. Any exception raised by the workers is re-raised after the
            worker pool has been terminated.
        """
        start_time = time.time()
        # Per-load instance state; append, no_temp_files, table_name and
        # from_disk are also handed to each worker via the Pool initializer.
        self.mother = True
        self.append = append
        self.no_temp_files = False
        self.table_name = table_name
        self.label_datatype = label_datatype
        self.from_disk = True
        self._validate_input_and_create_table()

        self.root_dir = root_dir

        # Only actual sub-directories are labels. This allows the user to keep
        # a tar.gz file or other extraneous files in the root directory
        # without causing any problems. os.listdir() returns entries in
        # arbitrary order, so sort them to make the "first num_labels labels"
        # selection reproducible between runs.
        labels = []
        for entry in sorted(os.listdir(root_dir)):
            if os.path.isdir(os.path.join(root_dir, entry)):
                labels.append(entry)
            else:
                print("{0} is not a directory, skipping".format(entry))

        if num_labels == 'all':
            print("Found {0} image labels in {1}".format(len(labels),
                                                         root_dir))
        else:
            requested = int(num_labels)
            if requested < 0:
                # A negative slice bound would silently drop labels from the
                # end instead of selecting the first N.
                raise ValueError(
                    "num_labels must be 'all' or a non-negative integer, "
                    "got {0!r}".format(num_labels))
            labels = labels[:requested]
            # Report how many labels are actually used; it can be fewer than
            # requested when root_dir holds fewer label dirs.
            print("Using first {0} image labels in {1}".format(len(labels),
                                                                root_dir))

        if not self.pool:
            print("Spawning {0} workers...".format(self.num_workers))
            self.pool = Pool(processes=self.num_workers,
                             initializer=init_worker,
                             initargs=(current_process().pid,
                                       self.table_name,
                                       self.append,
                                       self.no_temp_files,
                                       self.db_creds,
                                       self.from_disk,
                                       root_dir))
        try:
            self.pool.map(_call_disk_worker, labels)
            # One _worker_cleanup task per worker process.
            # NOTE(review): Pool.map does not guarantee that each worker gets
            # exactly one of these tasks -- confirm _worker_cleanup tolerates
            # that.
            self.pool.map(_worker_cleanup, [0] * self.num_workers)
        except BaseException:
            # Terminate the workers on any failure (including Ctrl-C and a
            # failing cleanup step) so they are not leaked, then re-raise with
            # the original traceback intact (bare `raise`, not `raise e`).
            self.terminate_workers()
            raise

        end_time = time.time()
        print("Done!  Loaded {0} image categories in {1}s"
              .format(len(labels), end_time - start_time))

        self.terminate_workers()