def _get_partition_key_filters()

in data_validation/partition_builder.py [0:0]


    def _get_partition_key_filters(self) -> List[List[List[str]]]:
        """Build the partition filter strings for every configured table pair.

        The PartitionBuilder object contains the configuration of the table pairs (source and
        target) to be validated and the args (number of partitions). For each table pair the
        source rows are numbered in primary key order, the primary key values of the first row
        of every partition are fetched with one query, and those boundary values are turned
        into filter strings. A partition filter is the string that is used in the where
        clause - e.g. 'x >=25 and x <50'. The number of partitions used is
        ``self.args.partition_num``, capped at the source row count. The design doc for this
        section is available in docs/internal/partition_table_prd.md

        Returns:
            A list of list of list of strings for the source and target tables for each table pair
            i.e. (list of strings - 1 per partition) x (source and target) x (number of table pairs)

        Raises:
            IndexError: If fewer than two partition boundary rows are found for a table pair
                (for example when only one partition is requested or the source has fewer
                than two rows); a first and a last partition filter cannot be built then.
        """

        def _match_column_type(key_column, value):
            # Due to issue 1474, the value can be a datetime.datetime while the column is a
            # date; compare like with like by reducing the value to a date.
            if key_column.type().is_date() and isinstance(value, datetime.datetime):
                return value.date()
            return value

        def less_than_value(table, keys, values):
            # Filter expression for all rows that sort strictly before the row holding
            # `values` in the (possibly multi column) `keys` sort order.
            key_column = table[keys[0]]
            value = _match_column_type(key_column, values[0])
            if len(keys) == 1:
                return key_column < value
            return (key_column < value) | (
                (key_column == value) & less_than_value(table, keys[1:], values[1:])
            )

        def geq_value(table, keys, values):
            # Filter expression for all rows that sort at or after the row holding `values`
            # in the `keys` sort order, i.e. including the row specified by `values`.
            key_column = table[keys[0]]
            value = _match_column_type(key_column, values[0])
            if len(keys) == 1:
                return key_column >= value
            return (key_column > value) | (
                (key_column == value) & geq_value(table, keys[1:], values[1:])
            )

        def partition_clause(table, keys, boundaries, index):
            # Filter expression for partition `index`: greater than or equal to the first
            # element of this partition and less than the first element of the next one.
            # The first and the last partitions are special - less than the first element of
            # the second partition and greater than or equal to the first element of the
            # last partition respectively.
            width = len(keys)
            if index == 0:
                return less_than_value(table, keys, boundaries[1, :width])
            lower_bound = geq_value(table, keys, boundaries[index, :width])
            if index == boundaries.shape[0] - 1:
                return lower_bound
            return lower_bound & less_than_value(
                table, keys, boundaries[index + 1, :width]
            )

        master_filter_list = []
        for config_manager in self.config_managers:  # For each pair of tables
            validation_builder = ValidationBuilder(config_manager)

            source_pks, target_pks = [], []
            for pk in config_manager.primary_keys:
                source_pks.append(pk["source_column"])
                target_pks.append(pk["target_column"])

            source_partition_row_builder = PartitionRowBuilder(
                source_pks,
                config_manager.source_client,
                config_manager.source_schema,
                config_manager.source_table,
                config_manager.source_query,
                validation_builder.source_builder,
            )
            source_table = source_partition_row_builder.query
            target_partition_row_builder = PartitionRowBuilder(
                target_pks,
                config_manager.target_client,
                config_manager.target_schema,
                config_manager.target_table,
                config_manager.target_query,
                validation_builder.target_builder,
            )
            target_table = target_partition_row_builder.query

            # Get Source and Target row Count
            source_count = source_partition_row_builder.get_count()
            target_count = target_partition_row_builder.get_count()

            # For some reason Teradata connector returns a dataframe with the count element,
            # while the other connectors return a numpy.int64 value
            if isinstance(source_count, pandas.DataFrame):
                source_count = source_count.values[0][0]
            if isinstance(target_count, pandas.DataFrame):
                target_count = target_count.values[0][0]

            if abs(source_count - target_count) > source_count * 0.1:
                logging.warning(
                    "Source and Target table row counts vary by more than 10%, "
                    "partitioning may result in partitions with very different sizes"
                )

            # Decide on number of partitions after checking number requested is not > number of rows in source
            number_of_part = (
                self.args.partition_num
                if self.args.partition_num < source_count
                else source_count
            )

            # First we number each row in the source table. Using row_number instead of ntile since it is
            # available on all platforms (Teradata does not support NTILE). For our purposes, it is likely
            # more efficient
            window1 = ibis.window(order_by=source_pks)
            row_number = (ibis.row_number().over(window1) + 1).name(consts.DVT_POS_COL)

            if config_manager.trim_string_pks():
                # String keys are right-stripped (keeping their name); other keys are used as is.
                dvt_keys = [
                    source_table[key].rstrip().name(key)
                    if source_table[key].type().is_string()
                    else key
                    for key in source_pks
                ]
            else:
                dvt_keys = source_pks.copy()

            dvt_keys.append(row_number)
            # Rownum table is just the primary key columns in the source table along with
            # an additional column with the row number associated with each row.
            rownum_table = source_table.select(dvt_keys)

            # The filter (where) clause condition below keeps the row numbers that correspond to the first
            # element of each partition. The number of a partition is
            # ceiling(row number * # of partitions / total number of rows). The first element of the partition is where
            # the remainder, i.e. row number * # of partitions % total number of rows is > 0 and <= number of partitions.
            # The remainder function does not work well with Teradata, hence writing that out explicitly.
            if source_count == number_of_part:
                # One row per partition: every row is the first element of its partition.
                cond = rownum_table
            else:
                scaled_position = rownum_table[consts.DVT_POS_COL] * number_of_part
                remainder = (
                    scaled_position
                    - (scaled_position / source_count).floor() * source_count
                )
                cond = (remainder <= number_of_part) & (remainder > 0)
            first_keys_table = rownum_table[cond].order_by(source_pks)

            # Up until this point, we have built the table expression, have not executed the query yet.
            # The query is now executed to find the first element of each partition
            first_elements = first_keys_table.execute().to_numpy()

            # The first partition is bounded by the first element of the second partition, so at
            # least two boundary rows are required. Fail with an explicit message (IndexError is
            # kept as the exception type, it is what the boundary lookup raised before).
            if first_elements.shape[0] < 2:
                raise IndexError(
                    "Partitioning requires at least two partitions per table pair, but only "
                    f"{first_elements.shape[0]} partition boundary row(s) were found"
                )

            # The objective is to generate the SQL expression string that is saved in the yaml file as a
            # filters property. This SQL expression is used as a filter during validation to ensure
            # that the yaml file is only validating the specific partition. This string is backend specific as
            # the SQL syntax varies slightly across backends. We get Ibis to generate the string for
            # a table expression with the filter (where) clause and then extract the SQL expression string.
            # The function _extract_where extracts the expression string from the Ibis SQL table expression.
            source_where_list = []
            target_where_list = []
            for index in range(first_elements.shape[0]):
                filter_source_clause = partition_clause(
                    source_table, source_pks, first_elements, index
                )
                filter_target_clause = partition_clause(
                    target_table, target_pks, first_elements, index
                )
                source_where_list.append(
                    self._extract_where(
                        source_table.filter(filter_source_clause),
                    )
                )
                target_where_list.append(
                    self._extract_where(
                        target_table.filter(filter_target_clause),
                    )
                )
            master_filter_list.append([source_where_list, target_where_list])
        return master_filter_list