dataflux_core/range_splitter.py (122 lines of code) (raw):

""" Copyright 2023 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from __future__ import annotations from collections.abc import Sequence from dataclasses import dataclass from fractions import Fraction from itertools import count @dataclass class MinimalIntRange: start_int: int end_int: int min_len: int @dataclass class GenerateSplitsOpts: min_int_range: MinimalIntRange num_splits: int start_range: str end_range: str class RangeSplitter(object): """Manages splits performed to facilitate the work-stealing algorithm. Attr: alphabet_map: An int mapping for an alphabet of arbitrary character size. sorted_alphabet: The sorted alphabet that initializes the RangeSplitter. """ min_splits = 2 def __init__(self, alphabet_map: dict[int, str], sorted_alphabet: Sequence[str]): self.alphabet_map = alphabet_map self.sorted_alphabet = sorted_alphabet self.alphabet_set = set(sorted_alphabet) def split_range( self, start_range: str, end_range: str, num_splits: int, ) -> Sequence[str]: """Creates a given number of splits based on a provided start and end range. Args: start_range (str): The string marking the start of the split range. end_range (str): The string marking the end of the split range. num_splits (int): The number of splitpoints to return. Returns: A sequence of split points dividing up the provided range. """ if num_splits < 1: raise ValueError("Got num_splits of %s but need minimum of %s." % (num_splits, self.min_splits)) if len(end_range) != 0 and start_range >= end_range: return [] if self.is_range_equal_with_padding(start_range, end_range): return [] self.add_characters_to_alphabet(start_range + end_range) min_int_range = self.string_to_minimal_int_range( start_range, end_range, num_splits) split_points = self.generate_splits( GenerateSplitsOpts(min_int_range, num_splits, start_range, end_range)) return split_points def generate_splits(self, opts: GenerateSplitsOpts) -> Sequence[str]: """Generates a list of split points. Args: opts (GenerateSplitOpts): Set of options for generating splitpoints Returns: A list of split points. """ start_int = opts.min_int_range.start_int end_int = opts.min_int_range.end_int min_len = opts.min_int_range.min_len range_diff = end_int - start_int split_points = [] range_interval = opts.num_splits + 1 adjustment = Fraction(range_diff / range_interval) for i in range(1, opts.num_splits + 1): split_point = start_int + adjustment * i split_string = self.int_to_string(int(split_point), min_len) is_greater_than_start = (len(split_string) > 0 and split_string > opts.start_range) is_less_than_end = len( opts.end_range) == 0 or (len(split_string) > 0 and split_string < opts.end_range) if is_greater_than_start and is_less_than_end: split_points.append(split_string) return split_points def int_to_string(self, split_point: int, string_len: int) -> str: """Converts the base len(alphabet) int back into a string. Args: split_point (int): A valid split point int to be converted to string. string_len (int): The required length of the resulting string. Returns: A string derived from a base len(alphabet) int. """ alphabet_len = len(self.sorted_alphabet) split_string = "" for _ in range(string_len): remainder = split_point % alphabet_len split_point //= alphabet_len split_string += self.sorted_alphabet[remainder] # This is assembeled backwards via division, so we reverse the final string. return split_string[::-1] def string_to_minimal_int_range(self, start_range: str, end_range: str, num_splits: int) -> MinimalIntRange: """Converts a string range to a minimal integer range. Args: start_range (str): The string marking the start of the split range. end_range (str): The string marking the end of the split range. num_splits (int): The number of splitpoints to return. Returns: A minimal integer range. """ start_int = 0 end_int = 0 alphabet_len = len(self.sorted_alphabet) start_char = self.sorted_alphabet[0] end_char = self.sorted_alphabet[-1] end_default_char = start_char if len(end_range) == 0: end_default_char = end_char for i in count(0): start_pos = self.alphabet_map[get_char_or_default( start_range, i, start_char)] start_int *= alphabet_len start_int += start_pos end_pos = self.alphabet_map[get_char_or_default( end_range, i, end_default_char)] end_int *= alphabet_len end_int += end_pos difference = end_int - start_int if difference > num_splits: # Due to zero indexing, min length must have 1 added to it. return MinimalIntRange(start_int, end_int, i + 1) def is_range_equal_with_padding(self, start_range: str, end_range: str): """Checks for equality between two string ranges. Args: start_range (str): The start range for the split. end_range (str): The end range for the split. Returns: Boolean indicating equality of the two provided ranges. """ if len(end_range) == 0: return False longest = max(len(start_range), len(end_range)) smallest_char = self.sorted_alphabet[0] for i in range(longest): char_start = get_char_or_default(start_range, i, smallest_char) char_end = get_char_or_default(end_range, i, smallest_char) if char_start != char_end: return False return True def add_characters_to_alphabet(self, characters: str): """Adds a character to the known alphabet. Args: characters: The string of characters to add to the library. """ unique_characters = set(characters) new_alphabet = self.alphabet_set.union(unique_characters) if len(new_alphabet) != len(self.alphabet_set): self.sorted_alphabet = sorted(new_alphabet) self.alphabet_map = { val: index for index, val in enumerate(self.sorted_alphabet) } def get_char_or_default(characters: str, index: int, default_char: str) -> str: """Returns the character at the given index or the default character if the index is out of bounds. Args: characters (str): The range string to check. index (int): The current iteration index across characters. default_char (str): The smallest character in the implemented char set. Returns: The resulting character for the given index. """ if index < 0 or index >= len(characters): return default_char return characters[index] def new_rangesplitter(alphabet: str) -> RangeSplitter: """Creates a new RangeSplitter with the given alphabets. Note that the alphabets are a predetermined set of characters by the work-stealing algorithm, and the characters are guaranteed to be unique. Args: alphabet (str): The full set of characters used for this range splitter. Returns: An instance of the RangeSplitter class that is used to manage splits performed to facilitate the work-stealing algorithm. """ if len(alphabet) == 0: raise ValueError("Cannot split with an empty alphabet.") sorted_alphabet = sorted(alphabet) alphabet_map = {val: index for index, val in enumerate(sorted_alphabet)} return RangeSplitter(alphabet_map, sorted_alphabet)