def uplift_tree_plot()

in causalml/inference/tree/plot.py [0:0]


def uplift_tree_plot(decisionTree, x_names):
    '''
    Convert the tree to dot graph for plots.

    Args
    ----

    decisionTree : object
        object of DecisionTree class

    x_names : list
        List of feature names

    Returns
    -------
    Dot class representing the tree graph.
    '''

    # Column Heading
    dcHeadings = {}
    for i, szY in enumerate(x_names + ['treatment_group_key']):
        szCol = 'Column %d' % i
        dcHeadings[szCol] = str(szY)

    dcNodes = defaultdict(list)
    """Plots the obtained decision tree. """

    def toString(iSplit, decisionTree, bBranch, szParent="null", indent='', indexParent=0, upliftScores=list()):
        if decisionTree.results is not None:  # leaf node
            lsY = []
            for szX, n in decisionTree.results.items():
                lsY.append('%s:%.2f' % (szX, n))
            dcY = {"name": "%s" % ', '.join(lsY), "parent": szParent}
            dcSummary = decisionTree.summary
            upliftScores += [dcSummary['matchScore']]
            dcNodes[iSplit].append(['leaf', dcY['name'], szParent, bBranch,
                                    str(-round(float(decisionTree.summary['impurity']), 3)), dcSummary['samples'],
                                    dcSummary['group_size'], dcSummary['upliftScore'], dcSummary['matchScore'],
                                    indexParent])
        else:
            szCol = 'Column %s' % decisionTree.col
            if szCol in dcHeadings:
                szCol = dcHeadings[szCol]
            if isinstance(decisionTree.value, int) or isinstance(decisionTree.value, float):
                decision = '%s >= %s' % (szCol, decisionTree.value)
            else:
                decision = '%s == %s' % (szCol, decisionTree.value)

            indexOfLevel = len(dcNodes[iSplit])
            toString(iSplit + 1, decisionTree.trueBranch, True, decision, indent + '\t\t', indexOfLevel, upliftScores)
            toString(iSplit + 1, decisionTree.falseBranch, False, decision, indent + '\t\t', indexOfLevel, upliftScores)
            dcSummary = decisionTree.summary
            upliftScores += [dcSummary['matchScore']]
            dcNodes[iSplit].append([iSplit + 1, decision, szParent, bBranch,
                                    str(-round(float(decisionTree.summary['impurity']), 3)), dcSummary['samples'],
                                    dcSummary['group_size'], dcSummary['upliftScore'], dcSummary['matchScore'],
                                    indexParent])

    upliftScores = list()
    toString(0, decisionTree, None, upliftScores=upliftScores)

    upliftScoreToColor = dict()
    try:
        # calculate colors for nodes based on uplifts
        minUplift = min(upliftScores)
        maxUplift = max(upliftScores)
        upliftLevels = [(uplift-minUplift)/(maxUplift-minUplift) for uplift in upliftScores]  # min max scaler
        baseUplift = float(decisionTree.summary.get('matchScore'))
        baseUpliftLevel = (baseUplift - minUplift) / (maxUplift - minUplift)  # min max scaler normalization
        white = np.array([255., 255., 255.])
        blue = np.array([31., 119., 180.])
        green = np.array([0., 128., 0.])
        for i, upliftLevel in enumerate(upliftLevels):
            if upliftLevel >= baseUpliftLevel:  # go blue
                color = upliftLevel * blue + (1 - upliftLevel) * white
            else:  # go green
                color = (1 - upliftLevel) * green + upliftLevel * white
            color = [int(c) for c in color]
            upliftScoreToColor[upliftScores[i]] = ('#%2x%2x%2x' % tuple(color)).replace(' ', '0')  # color code
    except Exception as e:
        print(e)

    lsDot = ['digraph Tree {',
             'node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;',
             'edge [fontname=helvetica] ;'
             ]
    i_node = 0
    dcParent = {}
    totalSample = int(decisionTree.summary.get('samples'))  # initialize the value with the total sample size at root
    for nSplit in range(len(dcNodes.items())):
        lsY = dcNodes[nSplit]
        indexOfLevel = 0
        for lsX in lsY:
            iSplit, decision, szParent, bBranch, szImpurity, szSamples, szGroup, \
                upliftScore, matchScore, indexParent = lsX

            sampleProportion = round(int(szSamples)*100./totalSample, 1)
            if type(iSplit) == int:
                szSplit = '%d-%d' % (iSplit, indexOfLevel)
                dcParent[szSplit] = i_node
                lsDot.append('%d [label=<%s<br/> impurity %s<br/> total_sample %s (%s&#37;)<br/>group_sample %s <br/> '
                             'uplift score: %s <br/> uplift p_value %s <br/> '
                             'validation uplift score %s>, fillcolor="%s"] ;' % (
                                 i_node, decision.replace('>=', '&ge;').replace('?', ''), szImpurity, szSamples,
                                 str(sampleProportion), szGroup, str(upliftScore[0]), str(upliftScore[1]),
                                 str(matchScore), upliftScoreToColor.get(matchScore, '#e5813900')
                             ))
            else:
                lsDot.append('%d [label=< impurity %s<br/> total_sample %s (%s&#37;)<br/>group_sample %s <br/> '
                             'uplift score: %s <br/> uplift p_value %s <br/> validation uplift score %s <br/> '
                             'mean %s>, fillcolor="%s"] ;' % (
                                 i_node, szImpurity, szSamples, str(sampleProportion), szGroup, str(upliftScore[0]),
                                 str(upliftScore[1]), str(matchScore), decision,
                                 upliftScoreToColor.get(matchScore, '#e5813900')
                             ))

            if szParent != 'null':
                if bBranch:
                    szAngle = '45'
                    szHeadLabel = 'True'
                else:
                    szAngle = '-45'
                    szHeadLabel = 'False'
                szSplit = '%d-%d' % (nSplit, indexParent)
                p_node = dcParent[szSplit]
                if nSplit == 1:
                    lsDot.append('%d -> %d [labeldistance=2.5, labelangle=%s, headlabel="%s"] ;' % (p_node,
                                                                                                    i_node, szAngle,
                                                                                                    szHeadLabel))
                else:
                    lsDot.append('%d -> %d ;' % (p_node, i_node))
            i_node += 1
            indexOfLevel += 1
    lsDot.append('}')
    dot_data = '\n'.join(lsDot)
    graph = pydotplus.graph_from_dot_data(dot_data)
    return graph