def parse_tree_model()

in smdebug/xgboost/utils.py [0:0]


def parse_tree_model(booster, iteration, fmap=""):
    """Parse a single tree of a boosted tree model text dump into a dictionary.

    This function is modified from xgboost.core.Booster.trees_to_dataframe() to
    take a Booster object and output a dictionary rather than a pandas dataframe.

    https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.trees_to_dataframe
    https://github.com/dmlc/xgboost/blob/release_0.90/python-package/xgboost/core.py#L1580

    Args:
        booster: Object with a ``get_dump(fmap=..., with_stats=...)`` method
            returning one text dump string per tree, e.g. ``xgboost.Booster``.
        iteration: Index (in dump order) of the tree to parse. When no tree
            has this index, every returned array is empty.
        fmap: Feature map file name forwarded to ``booster.get_dump`` so the
            dump, and therefore the "Feature" column, uses its feature names.

    Returns:
        A dict mapping the column names "Tree", "Node", "ID", "Feature",
        "Split", "Yes", "No", "Missing", "Gain" and "Cover" to equally long
        numpy arrays, ordered by node id. Leaf rows have Feature "Leaf", a NaN
        Split, the string "nan" in Yes/No/Missing and the leaf value in Gain.

    Raises:
        ValueError: If a split condition is not of the numerical form
            ``feature<threshold`` (e.g. indicator or categorical splits).
    """
    tree_ids = []
    node_ids = []
    fids = []
    splits = []
    y_directs = []
    n_directs = []
    missings = []
    gains = []
    covers = []

    def parse_stats(text):
        # "yes=1,no=2,missing=1,gain=3.5,cover=4" -> {"yes": "1", "no": "2", ...}
        # Looking stats up by name rather than by position keeps the parser
        # independent of the order in which the dump lists them.
        return dict(pair.split("=", 1) for pair in text.strip().split(","))

    # Forward fmap so feature names come from the feature map, as
    # Booster.trees_to_dataframe(fmap) does.
    trees = booster.get_dump(fmap=fmap, with_stats=True)
    for i, tree in enumerate(trees):
        # Only the tree at index ``iteration`` is parsed.
        if i < iteration:
            continue
        if i > iteration:
            break
        str_i = str(i)
        for line in tree.split("\n"):
            # Skip blank lines, including the empty string produced by the
            # dump's trailing newline.
            if not line.strip():
                continue
            # Every line starts with "<node id>:"; the rest describes the node.
            head, _, body = line.partition(":")
            node_id = int(re.findall(r"\b\d+\b", head)[0])

            if "[" not in body:
                # Leaf node, e.g. "\t\t3:leaf=0.1,cover=2".
                stats = parse_stats(body)
                tree_ids.append(i)
                node_ids.append(node_id)
                fids.append("Leaf")
                splits.append(float("NAN"))
                y_directs.append(float("NAN"))
                n_directs.append(float("NAN"))
                missings.append(float("NAN"))
                # As in trees_to_dataframe(), the leaf value goes in "Gain".
                gains.append(float(stats["leaf"]))
                covers.append(float(stats["cover"]))
            else:
                # Split node, e.g. "\t1:[f0<0.5] yes=3,no=4,missing=3,gain=2,cover=5".
                condition, _, stats_text = body.split("[", 1)[1].partition("]")
                feature, sep, threshold = condition.partition("<")
                if not sep:
                    raise ValueError(
                        "Failed to parse split condition %r of node %d in tree %d: "
                        "only numerical splits of the form 'feature<threshold' "
                        "are supported." % (condition, node_id, i)
                    )
                stats = parse_stats(stats_text)
                tree_ids.append(i)
                node_ids.append(node_id)
                fids.append(feature)
                splits.append(float(threshold))
                # Child references are made globally unique as "<tree>-<node>".
                y_directs.append(str_i + "-" + stats["yes"])
                n_directs.append(str_i + "-" + stats["no"])
                missings.append(str_i + "-" + stats["missing"])
                gains.append(float(stats["gain"]))
                covers.append(float(stats["cover"]))

    ids = [str(t_id) + "-" + str(n_id) for t_id, n_id in zip(tree_ids, node_ids)]

    key_to_array = {
        "Tree": np.array(tree_ids, dtype=np.dtype("int")),
        "Node": np.array(node_ids, dtype=np.dtype("int")),
        "ID": np.array(ids, dtype=np.dtype("U")),
        "Feature": np.array(fids, dtype=np.dtype("U")),
        "Split": np.array(splits, dtype=np.dtype("float")),
        "Yes": np.array(y_directs, dtype=np.dtype("U")),
        "No": np.array(n_directs, dtype=np.dtype("U")),
        "Missing": np.array(missings, dtype=np.dtype("U")),
        "Gain": np.array(gains, dtype=np.dtype("float")),
        "Cover": np.array(covers, dtype=np.dtype("float")),
    }
    # XGBoost's trees_to_dataframe() method uses
    # df.sort_values(['Tree', 'Node']).reset_index(drop=True) to order rows.
    # np.lexsort treats its LAST key as the primary one, so this sorts by Tree
    # and then Node without using pandas.
    indices = np.lexsort((key_to_array["Node"], key_to_array["Tree"]))
    result = {key: arr[indices] for key, arr in key_to_array.items()}
    return result