def _load_rows_with_predicate()

in petastorm/py_dict_reader_worker.py [0:0]


    def _load_rows_with_predicate(self, pq_file, piece, worker_predicate, shuffle_row_drop_partition):
        """Loads all rows that match a predicate from a piece.

        The piece is read in two passes so that the columns the predicate does not need are only decoded
        for rows that survive filtering:

        1. Read all columns needed by the predicate and decode them.
        2. Apply the predicate. If nothing matches, exit early.
        3. Read the remaining columns and decode them (matching rows only).
        4. Combine with the columns already decoded for the predicate.

        :param pq_file: Forwarded unchanged to ``self._read_with_shuffle_row_drop``.
        :param piece: Forwarded unchanged to ``self._read_with_shuffle_row_drop``.
        :param worker_predicate: Object providing ``get_fields()`` (names of the columns the predicate needs)
          and ``do_include(row)`` (truthy if the decoded row should be kept).
        :param shuffle_row_drop_partition: Forwarded unchanged to ``self._read_with_shuffle_row_drop``.
        :return: A list with one entry per matching row. An empty list is returned (and the transform spec is
          not applied) when no row matches. Otherwise, if ``self._transform_spec`` is set, the result of
          ``_apply_transform_spec`` is returned.
        :raises ValueError: If the predicate requests no fields, or requests fields that are not in the schema.
        """

        # Split all column names into ones that are needed by the predicate and the rest.
        predicate_column_names = set(worker_predicate.get_fields())

        if not predicate_column_names:
            raise ValueError('At least one field name must be returned by predicate\'s get_fields() method')

        all_schema_names = set(field.name for field in self._schema.fields.values())

        invalid_column_names = predicate_column_names - all_schema_names
        if invalid_column_names:
            # Sets have no defined iteration order; sort so that the message is reproducible.
            raise ValueError('At least some column names requested by the predicate ({}) '
                             'are not valid schema names: ({})'.format(', '.join(sorted(invalid_column_names)),
                                                                       ', '.join(sorted(all_schema_names))))

        # Partition names are excluded from the second read.
        other_column_names = all_schema_names - predicate_column_names - self._dataset.partitions.partition_names

        # Read columns needed for the predicate
        predicate_rows = self._read_with_shuffle_row_drop(piece, pq_file, predicate_column_names,
                                                          shuffle_row_drop_partition)

        # Decode values
        decoded_predicate_rows = [
            utils.decode_row(_select_cols(row, predicate_column_names), self._schema)
            for row in predicate_rows]

        # Use the predicate to filter
        match_predicate_mask = [worker_predicate.do_include(row) for row in decoded_predicate_rows]

        # Don't have anything left after filtering? Exit early.
        if not any(match_predicate_mask):
            return []

        # Remove rows that were filtered out by the predicate. The mask was built from these very rows,
        # so both sequences have the same length.
        filtered_decoded_predicate_rows = [row for row, keep in zip(decoded_predicate_rows, match_predicate_mask)
                                           if keep]

        if other_column_names:
            # Read remaining columns
            other_rows = self._read_with_shuffle_row_drop(piece, pq_file, other_column_names,
                                                          shuffle_row_drop_partition)

            # Remove rows that were filtered out by the predicate. The mask is applied positionally, so this
            # relies on the second read yielding rows in the same order as the first. Indexing (rather than
            # zipping) keeps a longer-than-expected second read from being silently truncated.
            filtered_other_rows = [row for i, row in enumerate(other_rows) if match_predicate_mask[i]]

            # Decode remaining columns (only for the rows that matched)
            decoded_other_rows = [utils.decode_row(row, self._schema) for row in filtered_other_rows]

            # Merge predicate needed columns with the remaining
            result = [_merge_two_dicts(other, matched)
                      for other, matched in zip(decoded_other_rows, filtered_decoded_predicate_rows)]
        else:
            result = filtered_decoded_predicate_rows

        if self._transform_spec:
            result = _apply_transform_spec(result, self._transform_spec)

        return result