odps/ml/expr/models/pmml.py (476 lines of code) (raw):

#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright 1999-2022 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import import re import sys import logging from collections import namedtuple from ....compat import six, ElementTree as ET, reduce from ....utils import require_package PMML_VOLUME = 'pyodps_volume' logger = logging.getLogger(__name__) class PmmlRepr(object): def _repr(self): return None def __repr__(self): r = self._repr() if r is None: return '' elif not six.PY2: return r else: return r.encode('utf-8') if isinstance(r, six.text_type) else r class PmmlPredictor(PmmlRepr): def __init__(self, element): self._element = element @property def coefficient(self): return self._element.attrib.get('coefficient') @property def name(self): return self._element.attrib.get('name') class PmmlNumericPredictor(PmmlPredictor): @property def exponent(self): exp = self._element.attrib.get('exponent') return None if exp is None or float(exp) == 1.0 else exp def _repr(self): if float(self.coefficient) == 0.0: return None if self.exponent: return '%s * %s ** %s' % (self.coefficient, self.name, self.exponent) else: return '%s * %s' % (self.coefficient, self.name) def _repr_html_(self): if float(self.coefficient) == 0.0: return None if self.exponent: return '%s * %s<sup>%s</sup>' % (self.coefficient, self.name, self.exponent) else: return '%s * %s' % (self.coefficient, self.name) class PmmlCategoricalPredictor(PmmlPredictor): @property def value(self): return self._element.attrib.get('value') def _repr(self): return '%s * I(%s = %s)' % (self.coefficient, self.name, self.value) def _repr_html_(self): return self._repr() class PmmlRegressionTable(PmmlRepr): def __init__(self, element, target_field='y'): self._element = element self._target_field = target_field @property def intercept(self): intercept = self._element.attrib.get('intercept') return intercept if float(intercept) != 0.0 else None @property def target_category(self): return self._element.attrib.get('targetCategory') def _numeric_predictors(self): return [PmmlNumericPredictor(el) for el in self._element.findall('NumericPredictor')] def _categorical_predictors(self): return [PmmlCategoricalPredictor(el) for el in self._element.findall('CategoricalPredictor')] def predictors(self): return self._numeric_predictors() + self._categorical_predictors() def _build_expr(self, rf=repr): expr_parts = [] if self.intercept: expr_parts.append(self.intercept) predictors = self.predictors() preds = [r for r in (rf(pred) for pred in predictors) if r] expr_parts.extend([pred for pred in preds if pred]) if len(predictors) == 0: return '' value_part = re.sub(r' *\+ *- *', ' - ', ' + '.join(expr_parts)) return self._target_field + ' = ' + value_part def _repr(self): expr = self._build_expr() if self.target_category and expr: return 'Target: %s\n ' % self.target_category + expr else: return expr def _repr_html_(self): expr = self._build_expr(rf=lambda v: v._repr_html_()) if not expr: return '' elif self.target_category: return '<div style="font-weight: bold">Target: %s</div><div style="text-indent: 30px">%s</div>' % (self.target_category, expr) else: return '<div>%s</div>' % expr class PmmlExpr(PmmlRepr): def __init__(self, s, func): self._str = s self._func = func def __str__(self): return repr(self) def _repr(self): return self._str def __call__(self, *args): return self._func(*args) EXPR_DICT = { 'equal': PmmlExpr(u'=', lambda a, b: a == b), 'notEqual': PmmlExpr(u'≠', lambda a, b: a != b), 'lessThan': PmmlExpr(u'<', lambda a, b: a < b), 'lessOrEqual': PmmlExpr(u'≤', lambda a, b: a <= b), 'greaterThan': PmmlExpr(u'>', lambda a, b: a > b), 'greaterOrEqual': PmmlExpr(u'≥', lambda a, b: a >= b), 'isIn': PmmlExpr('in', lambda a, b: a in b), 'isNotIn': PmmlExpr('not in', lambda a, b: a not in b), } SCORE_SIZE = 16 DETAIL_SIZE = 11 def escape_graphviz(src): target = src for cfr, cto in GV_ESCAPES: target = target.replace(cfr, cto) return target TEXT_TREE_CORNER = '└' TEXT_TREE_FILL = ' ' TEXT_TREE_FORK = '├' TEXT_TREE_HLINE = '─' TEXT_TREE_VLINE = '│' def build_text_tree(root): def write_node_step(writer, node, header, is_last): front_line = (TEXT_TREE_CORNER if is_last else TEXT_TREE_FORK) + TEXT_TREE_HLINE * 2 + ' ' writer.write(header + front_line + node.text + '\n') sub_nodes = list(node.children()) if sub_nodes: if is_last: append_header = TEXT_TREE_FILL * 2 + ' ' else: append_header = TEXT_TREE_VLINE + TEXT_TREE_FILL + ' ' for node in sub_nodes[:-1]: write_node_step(writer, node, header + append_header, False) write_node_step(writer, sub_nodes[-1], header + append_header, True) sio = six.StringIO() sio.write(root.text + '\n') root_children = list(root.children()) for snode in root_children[:-1]: write_node_step(sio, snode, '', False) write_node_step(sio, root_children[-1], '', True) return sio.getvalue() GV_ESCAPES = [ ('<', '&lt;'), ('>', '&gt;'), ] def build_gv_tree(root): counter = [0, ] def write_gv_step(writer, node, parent_key): writer.write(u'{0} {1};\n'.format(parent_key, node.gv_text)) sub_nodes = list(node.children()) if sub_nodes: for node in sub_nodes: counter[0] += 1 struct_str = u'struct%d' % counter[0] write_gv_step(writer, node, struct_str) writer.write(u'%s -> %s;\n' % (parent_key, struct_str)) sio = six.StringIO() sio.write(u'digraph {{\nroot {0};\n'.format(root.gv_text)) root_children = list(root.children()) for snode in root_children: counter[0] += 1 struct_str = u'struct%d' % counter[0] write_gv_step(sio, snode, struct_str) sio.write('root -> %s;\n' % struct_str) sio.write('}\n') return six.text_type(sio.getvalue()) if six.PY2: def unescape_string(s): return s.decode('string-escape') else: def unescape_string(s): return s.encode('utf-8').decode('unicode-escape') def parse_pmml_array(element): if element is None: return None parts = [] sio = six.StringIO() quoted = False last_ch = None for ch in element.text: if ch == '\"': if last_ch == '\\': sio.write('\"') else: quoted = not quoted elif not quoted and (ch == ' ' or ch == '\t'): s = sio.getvalue() if s: parts.append(unescape_string(s)) sio = six.StringIO() else: sio.write(ch) last_ch = ch s = sio.getvalue() if s: parts.append(unescape_string(s)) if 'n' in element.attrib: assert int(element.attrib['n']) == len(parts) array_type = element.attrib['type'].lower() if array_type == 'int': wrapper = int elif array_type == 'real': wrapper = float else: wrapper = lambda v: v return [wrapper(p) for p in parts] def pmml_predicate(name): def _decorator(cls): PmmlPredicate._subclasses[name] = cls return cls return _decorator class PmmlPredicate(PmmlRepr): _subclasses = dict() def __new__(cls, *args, **kwargs): if 'element' in kwargs: element = kwargs['element'] else: element = args[0] if element.tag in cls._subclasses: return object.__new__(cls._subclasses[element.tag]) def __init__(self, element): self._element = element @classmethod def exists(cls, element): return any(element.find(s) is not None for s in six.iterkeys(cls._subclasses)) @classmethod def iterate(cls, element): return (cls(el) for el in reduce(lambda a, b: a + b, (element.findall('./' + c) for c in six.iterkeys(cls._subclasses)), [])) @property def field(self): return self._element.attrib.get('field') @pmml_predicate('SimplePredicate') class PmmlSimplePredicate(PmmlPredicate): @property def operator(self): return self._element.attrib.get('operator') @property def value(self): return self._element.attrib.get('value') def _repr(self): return '`%s` %s %s' % (self.field, EXPR_DICT.get(self.operator, self.operator), self.value) @pmml_predicate('SimpleSetPredicate') class PmmlSimpleSetPredicate(PmmlPredicate): def __init__(self, element): super(PmmlSimpleSetPredicate, self).__init__(element) @property def operator(self): return self._element.attrib.get('booleanOperator') @property def array(self): return parse_pmml_array(self._element.find('Array')) def _repr(self): return '`%s` %s (%s)' % (self.field, EXPR_DICT.get(self.operator, self.operator), ', '.join(repr(v) for v in self.array)) @pmml_predicate('CompoundPredicate') class PmmlCompoundPredicate(PmmlPredicate): @property def operator(self): return self._element.attrib.get('booleanOperator') def predicates(self): return PmmlPredicate.iterate(self._element) def _repr(self): pds = [] for sub_pd in self.predicates(): pd = repr(sub_pd) if isinstance(sub_pd, PmmlCompoundPredicate): pd = '({0})'.format(pd) pds.append(pd) return ' {0} '.format(self.operator).join(pds) PmmlSegmentSummary = namedtuple('PmmlSegmentSummary', 'type id weight') class PmmlSegment(PmmlRepr): def __init__(self, element): if element.tag == 'Segment': self._segment_element = element else: self._segment_element = None @property def segment_id(self): return self._segment_element.attrib.get('id') if self._segment_element is not None else None @property def segment_weight(self): return self._segment_element.attrib.get('weight') if self._segment_element is not None else None @property def segment_summary(self): return PmmlSegmentSummary(type=self.__class__.__name__, id=self.segment_id, weight=self.segment_weight) class PmmlTreeNode(PmmlRepr): def __init__(self, element, text=None): self._element = element self._text = text @property def score(self): return self._element.attrib.get('score') @property def predicate(self): try: return next(PmmlPredicate.iterate(self._element)) except StopIteration: return None @property def text(self): if self._text: return self._text content = '' if 'score' in self._element.attrib: content += 'SCORE = ' + self._element.attrib['score'] + ' ' if PmmlPredicate.exists(self._element): dists = ['%s:%s' % (e.attrib['value'], e.attrib['recordCount']) for e in self._element.findall('./ScoreDistribution')] pstr = repr(self.predicate) expr = 'WHEN %s' % pstr if dists: expr += ' (COUNTS: %s)' % ', '.join(dists) content += expr return content @property def gv_text(self): if self._text: return u'[shape=record,label=<\n {0}\n>]'.format(self._text) label_lines = [] extra_style = '' if 'score' in self._element.attrib: label_lines.append(u'<FONT POINT-SIZE="{0}">{1}</FONT>'.format(SCORE_SIZE, self._element.attrib['score'])) extra_style = u'style=filled,fillcolor=azure2,' if PmmlPredicate.exists(self._element): dists = [u'%s:%s' % (e.attrib['value'], e.attrib['recordCount']) for e in self._element.findall('./ScoreDistribution')] expr = repr(self.predicate) if six.PY2: expr = expr.decode('utf-8') label_lines.append(u'<FONT POINT-SIZE="{0}">{1}</FONT>'.format(DETAIL_SIZE, escape_graphviz(expr))) if dists: label_lines.append(u'<FONT POINT-SIZE="{0}">LABELS: {1}</FONT>'.format(DETAIL_SIZE, ', '.join(dists))) label = u'<br />'.join(label_lines) return u'[shape=record,{0}label=<\n {1}\n>]'.format(extra_style, label) def children(self): return [PmmlTreeNode(el) for el in self._element.findall('./Node')] def _repr(self): return build_text_tree(self) def _repr_gv_(self): return build_gv_tree(self) @require_package('graphviz') def _repr_svg_(self): from graphviz import Source return Source(self._repr_gv_(), encoding='utf-8')._repr_svg_() class PmmlTree(PmmlSegment): def __init__(self, element): super(PmmlTree, self).__init__(element) if element.tag == 'Segment': self._element = element.find('TreeModel') else: self._element = element @property def root(self): return PmmlTreeNode(self._element.find('./Node'), 'ROOT') def _repr(self): return build_text_tree(self.root) def _repr_gv_(self): return build_gv_tree(self.root) @require_package('graphviz') def _repr_svg_(self): from graphviz import Source return Source(self._repr_gv_(), encoding='utf-8')._repr_svg_() class PmmlResult(object): _result_types = [] def __new__(cls, pmml): et = ET.fromstring(re.sub(' xmlns="[^"]+"', '', pmml, count=1)) obj = None if cls is not PmmlResult: obj = object.__new__(cls) else: if not cls._result_types: for c in six.itervalues(globals()): if c is not PmmlResult and isinstance(c, type) and issubclass(c, PmmlResult): cls._result_types.append(c) for c in cls._result_types: if c.adaptable(et): obj = object.__new__(c) break if obj is None: obj = object.__new__(cls) obj.pmml = pmml obj._pmml_element = et return obj @classmethod def adaptable(cls, et): raise NotImplementedError class PmmlRegressionResult(PmmlResult, PmmlRepr): def __init__(self, *_, **__): if self._pmml_element.find('RegressionModel') is not None: self._reg_element = self._pmml_element.find('RegressionModel') elif self._pmml_element.find('MiningModel/Regression') is not None: self._reg_element = self._pmml_element.find('MiningModel/Regression') @classmethod def adaptable(cls, et): if et.find('RegressionModel') is not None: return True elif et.find('MiningModel/Regression') is not None: return True return False @property def target_field(self): v = self._reg_element.attrib.get('targetFieldName') return v if v else 'y' @property def normalization(self): v = self._reg_element.attrib.get('normalizationMethod') return v if v != 'none' else None @property def function(self): return self._reg_element.attrib.get('functionName') def __iter__(self): return (PmmlRegressionTable(e) for e in self._reg_element.findall('RegressionTable')) def _repr(self): sio = six.StringIO() sio.write('Function: %s\n' % self.function) sio.write('Target Field: %s\n' % self.target_field) if self.normalization: sio.write('Normalization: %s\n' % self.normalization) reprs = (repr(v) for v in self) sio.write('\n'.join(v for v in reprs if v)) return sio.getvalue() def _repr_html_(self): sio = six.StringIO() sio.write('<div><span style="font-weight: bold">Function</span>: %s</div>\n' % self.function) sio.write('<div><span style="font-weight: bold">Target Field</span>: %s</div>\n' % self.target_field) if self.normalization: sio.write('<div><span style="font-weight: bold">Normalization</span>: %s</div>\n' % self.normalization) reprs = (v._repr_html_() for v in self) sio.write('\n'.join(v for v in reprs if v)) return sio.getvalue() class PmmlSegmentsResult(PmmlResult): def __init__(self, *_, **__): self._seg_element = self._pmml_element.find('MiningModel/Segmentation') @classmethod def adaptable(cls, et): return et.find('MiningModel/Segmentation') is not None @staticmethod def _segment_to_object(xsegment): if xsegment.find('./TreeModel') is not None: return PmmlTree(xsegment) else: raise ValueError('Unrecognized PMML node.') def __getitem__(self, item): if sys.version_info[:2] < (2, 7): xsegment = None for seg in self._seg_element.findall('Segment'): if seg.attrib.get('id') == str(item): xsegment = seg break else: xsegment = self._seg_element.find('Segment[@id="{0}"]'.format(item)) if xsegment is None: raise KeyError('No segments found in PMML result.') return self._segment_to_object(xsegment) def __iter__(self): return (self._segment_to_object(xseg) for xseg in self._seg_element.findall('./Segment')) def _get_segments_summary(self): return [seg.segment_summary for seg in self if seg.segment_id is not None] def _repr(self): return repr(self._get_segments_summary()) def _repr_html_(self): html_writer = six.StringIO() html_writer.write('<table><tr><th>ID</th><th>Type</th><th>Weight</th></tr>') for seg in self._get_segments_summary(): html_writer.write('<tr><td>{0}</td><td>{1}</td><td>{2}</td></tr>'.format(seg.id, seg.type, seg.weight)) html_writer.write('</table>') return html_writer.getvalue()