perfkitbenchmarker/flag_util.py (276 lines of code) (raw):

# Copyright 2018 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for working with user-supplied flags."""

import enum
import logging
import os
import re

from absl import flags
from perfkitbenchmarker import errors
from perfkitbenchmarker import flag_alias
from perfkitbenchmarker import units
import yaml

FLAGS = flags.FLAGS

# One group of a dash-separated integer list: "N", "LOW-HIGH" or
# "LOW-HIGH-STEP". Only nonnegative values can be written this way.
INTEGER_GROUP_REGEXP = re.compile(r'(\d+)(-(\d+))?(-(\d+))?$')
# The colon-separated variant, which also accepts negative values:
# "N", "LOW:HIGH" or "LOW:HIGH:STEP".
INTEGER_GROUP_REGEXP_COLONS = re.compile(r'(-?\d+)(:(-?\d+))?(:(-?\d+))?$')


class IntegerList:
  """An immutable list of integers.

  The list contains either single integers (ex: 5) or ranges (ex: 8-12).
  Additionally, the user can provide a step to the range like so: 8-24-2. The
  list can include as many elements as will fit in memory. Furthermore, the
  memory required to hold a range will not grow with the size of the range.

  Make a list with
    lst = IntegerList(groups)

  where groups is a list whose elements are either single integers, 2-tuples
  holding the low and high bounds of a range (inclusive), or 3-tuples holding
  the low and high bounds, followed by the step size. (Ex: [5, (8, 12)]
  represents the integer list 5,8,9,10,11,12, and [(8, 14, 2)] represents the
  list 8,10,12,14.)

  For negative number ranges use a colon separator (ex: "-2:1" is the integer
  list -2, -1, 0, 1).
  """

  def __init__(self, groups):
    self.groups = groups

    # Ranges contribute their element count without being materialized.
    length = 0
    for elt in groups:
      if isinstance(elt, int):
        length += 1
      if isinstance(elt, tuple):
        length += len(self._CreateXrangeFromTuple(elt))

    self.length = length

  def __len__(self):
    return self.length

  def __getitem__(self, idx):
    """Returns the element at position idx.

    Args:
      idx: int. A nonnegative index; negative indices are not supported.

    Returns:
      The integer at position idx.

    Raises:
      TypeError: if idx is not an int.
      IndexError: if idx is negative or not less than len(self).
    """
    if not isinstance(idx, int):
      raise TypeError()
    if idx < 0 or idx >= self.length:
      raise IndexError()

    # Consume idx group by group until it falls inside one. Empty ranges
    # (e.g. (5, 1) with the default step of 1) are skipped because idx can
    # never be smaller than their length of 0.
    for group in self.groups:
      if isinstance(group, tuple):
        group_range = self._CreateXrangeFromTuple(group)
        if idx < len(group_range):
          return group_range[idx]
        idx -= len(group_range)
      else:
        if idx == 0:
          return group
        idx -= 1
    raise IndexError()

  def __eq__(self, other):
    if other is None:
      return False
    try:
      other_values = tuple(other)
    except TypeError:
      # Not iterable: defer to Python's default comparison instead of raising.
      return NotImplemented
    return tuple(self) == other_values

  def __ne__(self, other):
    result = self.__eq__(other)
    if result is NotImplemented:
      return NotImplemented
    return not result

  def __iter__(self):
    for group in self.groups:
      if isinstance(group, int):
        yield group
      else:
        yield from self._CreateXrangeFromTuple(group)

  def __str__(self):
    return IntegerListSerializer().serialize(self)

  def __repr__(self):
    return 'IntegerList([%s])' % self

  def _CreateXrangeFromTuple(self, input_tuple):
    """Returns the range described by a (low, high[, step]) tuple.

    The high bound is inclusive, so the exclusive stop handed to range() is
    moved one past it in the direction of the step.
    """
    start = input_tuple[0]
    step = 1 if len(input_tuple) == 2 else input_tuple[2]
    stop_inclusive = input_tuple[1] + (1 if step > 0 else -1)
    return range(start, stop_inclusive, step)


def _IsNonIncreasing(result, val):
  """Determines if result would be non-increasing if val is appended.

  Args:
    result: list of integers and/or (low, high[, step]) range tuples, as built
      by IntegerListParser.parse (which rejects zero steps).
    val: integer to append.

  Returns:
    bool indicating if the appended list is non-increasing.
  """
  if not result:
    return False
  last = result[-1]
  if isinstance(last, tuple):
    # Compare against the last value the range actually yields; it differs
    # from the high bound when the step skips over it (1-10-4 yields 1,5,9).
    # An empty range has no last value, so fall back to its high bound.
    low, high = last[0], last[1]
    step = 1 if len(last) == 2 else last[2]
    produced = range(low, high + (1 if step > 0 else -1), step)
    prev = produced[-1] if produced else high
  else:
    prev = last
  return val <= prev


class IntegerListParser(flags.ArgumentParser):
  """Parse a string containing a comma-separated list of integers.

  The list may contain single integers and dash-separated ranges. For
  example, "1,3,5-7" parses to [1,3,5,6,7] and "1-7-3" parses to
  [1,4,7].

  Can pass the flag on_nonincreasing to the constructor to tell it
  what to do if the list is nonincreasing. Options are
    - None: do nothing.
    - IntegerListParser.WARN: log a warning.
    - IntegerListParser.EXCEPTION: raise a ValueError.

  As a special case, instead of a string, can pass a list of integers
  or an IntegerList. In these cases, the return value iterates over
  the same integers as were in the argument.

  For negative number ranges use a colon separator, for example "-3:4:2" parses
  to [-3, -1, 1, 3].
  """

  syntactic_help = (
      'A comma-separated list of integers or integer '
      'ranges. Ex: -1,3,5:7 is read as -1,3,5,6,7.'
  )

  WARN = 'warn'
  EXCEPTION = 'exception'

  def __init__(self, on_nonincreasing=None):
    super().__init__()
    self.on_nonincreasing = on_nonincreasing

  def parse(self, inp):
    """Parse an integer list.

    Args:
      inp: a string, a list, an int, or an IntegerList.

    Returns:
      An iterable of integers.

    Raises:
      ValueError: if inp doesn't follow a format it recognizes, or if a range
        has a step of zero.
    """
    if isinstance(inp, IntegerList):
      return inp
    elif isinstance(inp, list):
      return IntegerList(inp)
    elif isinstance(inp, int):
      return IntegerList([inp])
    elif not isinstance(inp, str):
      # Raise the documented ValueError rather than an AttributeError from
      # inp.split() below (e.g. for a float coming from a YAML config).
      raise ValueError('Invalid integer list %r' % (inp,))

    def HandleNonIncreasing():
      if self.on_nonincreasing == IntegerListParser.WARN:
        logging.warning('Integer list %s is not increasing', inp)
      elif self.on_nonincreasing == IntegerListParser.EXCEPTION:
        raise ValueError('Integer list %s is not increasing' % inp)

    groups = inp.split(',')
    result = []

    for group in groups:
      match = INTEGER_GROUP_REGEXP.match(
          group
      ) or INTEGER_GROUP_REGEXP_COLONS.match(group)
      if match is None:
        raise ValueError('Invalid integer list %s' % inp)
      elif match.group(2) is None:
        val = int(match.group(1))

        if _IsNonIncreasing(result, val):
          HandleNonIncreasing()

        result.append(val)
      else:
        low = int(match.group(1))
        high = int(match.group(3))
        step = int(match.group(5)) if match.group(5) is not None else 1
        # A positive step on a descending range (e.g. "5-1") is negated so the
        # range walks downwards from low to high.
        step = -step if step > 0 and low > high else step
        if step == 0:
          # range() cannot represent a zero step; report it explicitly.
          raise ValueError(
              'Invalid integer list %s: step must be nonzero' % inp
          )

        if high <= low or (_IsNonIncreasing(result, low)):
          HandleNonIncreasing()

        result.append((low, high, step))

    return IntegerList(result)

  def flag_type(self):
    return 'integer list'


class IntegerListSerializer(flags.ArgumentSerializer):
  """Serializes an IntegerList into the syntax IntegerListParser accepts."""

  def _SerializeRange(self, val):
    # The dash syntax only accepts nonnegative values, so switch to colons
    # whenever a bound or the step is negative.
    separator = ':' if any(item < 0 for item in val) else '-'
    return separator.join(str(item) for item in val)

  def serialize(self, il):
    return ','.join([
        str(val) if isinstance(val, int) else self._SerializeRange(val)
        for val in il.groups
    ])


def DEFINE_integerlist(
    name, default, help, on_nonincreasing=None, flag_values=FLAGS, **kwargs
):
  """Register a flag whose value must be an integer list."""

  parser = IntegerListParser(on_nonincreasing=on_nonincreasing)
  serializer = IntegerListSerializer()

  flags.DEFINE(parser, name, default, help, flag_values, serializer, **kwargs)


class OverrideFlags:
  """Context manager that applies any config_dict overrides to flag_values."""

  def __init__(
      self, flag_values, config_dict, alias=flag_alias.ALL_TRANSLATIONS
  ):
    """Initializes an OverrideFlags context manager.

    Args:
      flag_values: FlagValues that is temporarily modified so that any options
        in config_dict that are not 'present' in flag_values are applied to
        flag_values. Upon exit, flag_values will be restored to its original
        state.
      config_dict: Merged config flags from the benchmark config and benchmark
        configuration yaml file.
      alias: Alias to rename the flags to.
    """
    self._flag_values = flag_values
    self._config_dict = flag_alias.AliasFlagsFromYaml(config_dict, alias)
    # Maps flag name -> original value for every flag this manager overrode.
    self._flags_to_reapply = {}

  def __enter__(self):
    """Overrides flag_values with options in config_dict.

    Flags already 'present' (e.g. set on the command line) are left alone.

    Returns:
      This OverrideFlags instance.

    Raises:
      errors.Config.UnrecognizedOption: if a config_dict key is not a flag in
        flag_values.
      errors.Config.InvalidValue: if a value cannot be parsed by its flag.
    """
    if not self._config_dict:
      return self
    try:
      for key, value in self._config_dict.items():
        self._ApplyOverride(key, value)
    except Exception:
      # __exit__ is not called when __enter__ raises, so undo the overrides
      # applied so far before letting the error propagate.
      self._RestoreFlags()
      raise
    return self

  def __exit__(self, *unused_args, **unused_kwargs):
    """Restores flag_values to its original state."""
    self._RestoreFlags()

  def _ApplyOverride(self, key, value):
    """Applies one config override unless the flag is already present."""
    if key not in self._flag_values:
      raise errors.Config.UnrecognizedOption(
          'Unrecognized option {0}.{1}. Each option within {0} must '
          'correspond to a valid command-line flag.'.format('flags', key)
      )
    if self._flag_values[key].present:
      return
    self._flags_to_reapply[key] = self._flag_values[key].value
    try:
      self._flag_values[key].parse(value)  # Set 'present' to True.
    except flags.IllegalFlagValueError as e:
      raise errors.Config.InvalidValue(
          'Invalid {}.{} value: "{}" (of type "{}").{}{}'.format(
              'flags', key, value, value.__class__.__name__, os.linesep, e
          )
      ) from e

  def _RestoreFlags(self):
    """Puts back the original value of every overridden flag, marking it
    as not present, and forgets the recorded originals."""
    for key, value in self._flags_to_reapply.items():
      self._flag_values[key].value = value
      self._flag_values[key].present = 0
    self._flags_to_reapply = {}


class UnitsParser(flags.ArgumentParser):
  """Parse a flag containing a unit expression.

  Attributes:
    convertible_to: list of units.Unit instances. A parsed expression must be
      convertible to at least one of the Units in this list. For example, if
      the parser requires that its inputs are convertible to bits, then values
      expressed in KiB and GB are valid, but values expressed in meters are
      not.
  """

  syntactic_help = 'A quantity with a unit. Ex: 12.3MB.'

  def __init__(self, convertible_to):
    """Initialize the UnitsParser.

    Args:
      convertible_to: Either an individual unit specification or a series of
        unit specifications, where each unit specification is either a string
        (e.g. 'byte') or a units.Unit. The parser input must be convertible to
        at least one of the specified Units, or the parse() method will raise
        a ValueError.
    """
    super().__init__()
    if isinstance(convertible_to, (str, units.Unit)):
      self.convertible_to = [units.Unit(convertible_to)]
    else:
      self.convertible_to = [units.Unit(u) for u in convertible_to]

  def parse(self, inp):
    """Parse the input.

    Args:
      inp: a string or a units.Quantity. If a string, it has the format
        "<number><units>", as in "12KB", or "2.5GB".

    Returns:
      A units.Quantity.

    Raises:
      ValueError: If the input cannot be parsed, or if it parses to a value
        with improper units.
    """
    if isinstance(inp, units.Quantity):
      quantity = inp
    else:
      try:
        quantity = units.ParseExpression(inp)
      except Exception as e:
        raise ValueError(
            "Couldn't parse unit expression %r: %s" % (inp, str(e))
        ) from e
      if not isinstance(quantity, units.Quantity):
        raise ValueError('Expression %r evaluates to a unitless value.' % inp)

    for unit in self.convertible_to:
      try:
        quantity.to(unit)
        break
      except units.DimensionalityError:
        pass
    else:
      raise ValueError(
          'Expression {!r} is not convertible to an acceptable unit '
          '({}).'.format(inp, ', '.join(str(u) for u in self.convertible_to))
      )

    return quantity


class UnitsSerializer(flags.ArgumentSerializer):
  """Serializes a units.Quantity flag value via str()."""

  def serialize(self, units):
    return str(units)


def DEFINE_units(
    name, default, help, convertible_to, flag_values=flags.FLAGS, **kwargs
):
  """Register a flag whose value is a units expression.

  Args:
    name: string. The name of the flag.
    default: units.Quantity. The default value.
    help: string. A help message for the user.
    convertible_to: Either an individual unit specification or a series of
      unit specifications, where each unit specification is either a string
      (e.g. 'byte') or a units.Unit. The flag value must be convertible to at
      least one of the specified Units to be considered valid.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    **kwargs: extra arguments to pass to absl.flags.DEFINE().
  """
  parser = UnitsParser(convertible_to=convertible_to)
  serializer = UnitsSerializer()
  flags.DEFINE(parser, name, default, help, flag_values, serializer, **kwargs)


def StringToBytes(string):
  """Convert an object size, represented as a string, to bytes.

  Args:
    string: the object size, as a string with a quantity and a unit.

  Returns:
    an integer. The number of bytes in the size.

  Raises:
    ValueError, if either the string does not represent an object size
    or if the size does not contain an integer number of bytes.
  """
  try:
    quantity = units.ParseExpression(string)
  except Exception as e:
    # Catching all exceptions is ugly, but we don't know what sort of
    # exception pint might throw, and we want to turn any of them into
    # ValueError.
    raise ValueError("Couldn't parse size %s" % string) from e

  # A unitless expression does not parse to a Quantity and has no m_as().
  if not isinstance(quantity, units.Quantity):
    raise ValueError('Quantity %s is not a size' % string)

  try:
    num_bytes = quantity.m_as(units.byte)
  except units.DimensionalityError as e:
    raise ValueError('Quantity %s is not a size' % string) from e

  try:
    whole_bytes = int(num_bytes)
  except (OverflowError, ValueError) as e:
    # int() rejects infinite (OverflowError) and NaN (ValueError) magnitudes.
    raise ValueError(
        'Size %s has a non-integer number (%s) of bytes!' % (string, num_bytes)
    ) from e
  if num_bytes != whole_bytes:
    raise ValueError(
        'Size %s has a non-integer number (%s) of bytes!' % (string, num_bytes)
    )

  if whole_bytes < 0:
    raise ValueError('Size %s has a negative number of bytes!' % string)

  return whole_bytes


def StringToRawPercent(string):
  """Convert a string to a raw percentage value.

  Args:
    string: the percentage, with '%' on the end.

  Returns:
    A floating-point number, holding the percentage value.

  Raises:
    ValueError, if the string can't be read as a percentage.
  """

  if len(string) <= 1:
    raise ValueError("String '%s' too short to be percentage." % string)
  if string[-1] != '%':
    raise ValueError("Percentage '%s' must end with '%%'" % string)

  # This will raise a ValueError if it can't convert the string to a float.
  val = float(string[:-1])
  # The chained comparison is False for NaN, so "nan%" is rejected as well.
  if not 0.0 <= val <= 100.0:
    raise ValueError('Quantity %s is not a valid percentage' % val)

  return val


# The YAML flag type is necessary because flags can be read either via
# the command line or from a config file. If they come from a config
# file, they will already be parsed as YAML, but if they come from the
# command line, they will be raw strings. The point of this flag is to
# guarantee a consistent representation to the rest of the program.
class YAMLParser(flags.ArgumentParser):
  """Parse a flag containing YAML."""

  syntactic_help = 'A YAML expression.'

  def parse(self, inp):
    """Parse the input.

    Args:
      inp: A string or the result of yaml.safe_load. If a string, should be a
        valid YAML document.

    Returns:
      The yaml.safe_load() result for a string; any other input unchanged.

    Raises:
      ValueError: if inp is a string that is not valid YAML.
    """

    if isinstance(inp, str):
      # This will work unless the user writes a config with a quoted
      # string that, if unquoted, would be parsed as a non-string
      # Python type (example: '123'). In that case, the first
      # yaml.safe_load() in the config system will strip away the quotation
      # marks, and this second yaml.safe_load() will parse it as the
      # non-string type. However, I think this is the best we can do
      # without significant changes to the config system, and the
      # problem is unlikely to occur in PKB.
      try:
        return yaml.safe_load(inp)
      except yaml.YAMLError as e:
        raise ValueError(
            "Couldn't parse YAML string '%s': %s" % (inp, str(e))
        ) from e
    else:
      return inp


class YAMLSerializer(flags.ArgumentSerializer):
  """Serializes a YAML flag value with yaml.dump()."""

  def serialize(self, val):
    return yaml.dump(val)


def DEFINE_yaml(name, default, help, flag_values=flags.FLAGS, **kwargs):
  """Register a flag whose value is a YAML expression.

  Args:
    name: string. The name of the flag.
    default: object. The default value of the flag.
    help: string. A help message for the user.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    kwargs: extra arguments to pass to absl.flags.DEFINE().
  """

  parser = YAMLParser()
  serializer = YAMLSerializer()

  flags.DEFINE(parser, name, default, help, flag_values, serializer, **kwargs)


def ParseKeyValuePairs(strings):
  """Parses colon separated key value pairs from a list of strings.

  Pairs should be separated by a comma and key and value by a colon, e.g.,
  ['k1:v1', 'k2:v2,k3:v3']. Only the first colon splits a pair, so values may
  themselves contain colons. Malformed entries are logged and skipped.

  Args:
    strings: A list of strings.

  Returns:
    A dict populated with keys and values from the flag.
  """
  pairs = {}
  for pair in [kv for s in strings for kv in s.split(',')]:
    try:
      key, value = pair.split(':', 1)
      pairs[key] = value
    except ValueError:
      logging.error('Bad key value pair format. Skipping "%s".', pair)
      continue

  return pairs


def GetProvidedCommandLineFlags():
  """Return flag names and values that were specified on the command line.

  Enum values are reported by their member name.

  Returns:
    A dictionary of provided flags in the form: {flag_name: flag_value}.
  """

  def _GetSerializeableValue(v):
    if isinstance(v, enum.Enum):
      return v.name
    return v

  return {
      k: _GetSerializeableValue(FLAGS[k].value)
      for k in FLAGS
      if FLAGS[k].present
  }