6

I want to use xgboost in python for my upcoming model. However since our production system is in SAS, I am trying to extract decision rules from xgboost and then write a SAS scoring code to implement this model in SAS environment.

I have gone through multiple links to this. below are some of them:

How to extract decision rules (features splits) from xgboost model in python3?

xgboost deployment

The above two links are of great help specifically the code given by Shiutang-Li for xgboost deployment. However, my predicted scores are not exactly matching.

below is the code I have tried so far:

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.grid_search import GridSearchCV
%matplotlib inline
import graphviz
from graphviz import Digraph

# ------------------------------------------------------------------
# Load the sample iris data, build a binary target, and train a small
# XGBoost model whose text dump will be translated into SAS code.
# ------------------------------------------------------------------
iris = pd.read_csv("C:\\Users\\XXXX\\Downloads\\Iris.csv")

# Recode the target: 1 when class == 2, else 0.
# NOTE: the "!= 2 -> 0" assignment must run first; running the
# "== 2 -> 1" assignment first would let the second statement
# clobber the freshly created 1s.
iris.loc[iris["class"] != 2, "class"] = 0
iris.loc[iris["class"] == 2, "class"] = 1

# Predictors and response.
X = iris[["sepal_length", "sepal_width", "petal_length", "petal_width"]]
Y = iris["class"]

# DMatrix is XGBoost's optimized internal data container.
xgdmat = xgb.DMatrix(X, Y)

# Hyper-parameters for a small 10-round binary:logistic model.
our_params = {
    'eta': 0.1,
    'seed': 0,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
    'objective': 'binary:logistic',
    'max_depth': 3,
    'min_child_weight': 1,
}
Base_Model = xgb.train(our_params, xgdmat, num_boost_round=10)

#Below code reads the dump file created by xgboost and writes a scoring code in SAS:

import re
def string_parser(s):
    """Translate one line of an XGBoost text dump into SAS scoring code.

    Split nodes become ``if state = N then do; ...`` blocks that route the
    observation to the next node id; leaf lines accumulate the leaf weight
    into the SAS variable ``value``.

    Parameters
    ----------
    s : str
        A single line of ``Booster.get_dump()`` output, e.g.
        ``"\\t1:[petal_length<4.75] yes=3,no=4,missing=3"`` or
        ``"\\t\\t3:leaf=0.1586207"``.

    Returns
    -------
    str
        Equivalent SAS statement(s), indented to mirror the tree depth.
    """
    # Token pattern: '+' is included so scientific notation such as
    # "9.95e+09" stays one token (plain [\w.-]+ splits it at the '+').
    tokens = re.findall(r"[\w.+-]+", s)
    # Convert the dump's tab depth to 4-space SAS indentation; a root
    # line (or a single-leaf tree) has no tabs, so fall back to ''.
    tabs = re.findall(r"[\t]+", s)
    indent = tabs[0].replace('\t', '    ') if tabs else ''

    if ":leaf=" not in s:
        # Split node: tokens are
        # [node_id, feature, threshold, 'yes', yes_id, 'no', no_id, 'missing', missing_id]
        if tokens[4] == tokens[8]:
            # Missing values follow the 'yes' (less-than) branch.
            missing_value_handling = " or missing(" + tokens[1] + ")"
        else:
            # Missing values follow the 'no' branch.  SAS sorts missing
            # below every number, so the test must explicitly exclude
            # missing or a missing value would wrongly take 'yes'.
            missing_value_handling = " and not missing(" + tokens[1] + ")"

        return (indent +
                '        if state = ' + tokens[0] + ' then do;\n' +
                indent +
                '            if ' + tokens[1] + ' < ' + tokens[2] + missing_value_handling +
                ' then state = ' + tokens[4] + ';' + ' else state = ' + tokens[6] + ';\nend;')
    else:
        # Leaf node: tokens are [node_id, 'leaf', weight].
        return (indent +
                '        if state = ' + tokens[0] + ' then\n    ' +
                indent +
                '        value = value + (' + tokens[2] + ') ;\n')

def tree_parser(tree, i):
    """Convert one dumped tree into SAS code.

    Emits a reset of ``state`` to the root node, then translates every
    line of the dump except the trailing empty string produced by the
    final newline.  The ``i`` argument (tree index) is kept for
    interface compatibility; it is shadowed and unused in the original.
    """
    lines = tree.split('\n')[:-1]
    return 'state = 0;\n' + "".join(string_parser(line) for line in lines)

def model_to_sas(model, out_file):
    """Write a SAS scoring script for a trained XGBoost booster.

    Parameters
    ----------
    model : xgboost.Booster
        Trained model; its text dump is parsed tree by tree.
    out_file : str
        Path of the SAS file to create (overwritten if present).
    """
    # Accumulate every translated tree, then write in one shot.
    pieces = ["value = 0;\n"]
    pieces.extend(tree_parser(tree, idx)
                  for idx, tree in enumerate(model.get_dump()))
    with open(out_file, 'w') as sas_file:
        sas_file.write("".join(pieces))
        # Logistic link: convert the accumulated margin to a probability.
        sas_file.write("\nY_Pred1 = 1/(1+exp(-value));\n")
        sas_file.write("Y_Pred0 = 1 - Y_pred1;")

Call the above module to create the SAS scoring code:

model_to_sas(Base_Model, 'xgb_scr_code.sas')

Unfortunately I can't provide full SAS code generated by the above module. However, please find below the SAS code if we build the model using only one tree code:

value = 0;
state = 0;
if state = 0 then
    do;
        if sepal_width < 2.95000005 or missing(sepal_width) then state = 1;
        else state = 2;
    end;
if state = 1 then
    do;
        if petal_length < 4.75 or missing(petal_length) then state = 3;
        else state = 4;
    end;

if state = 3 then   value = value + (0.1586207);
if state = 4 then   value = value + (-0.127272725);
if state = 2 then
    do;
        if petal_length < 3 or missing(petal_length) then state = 5;
        else state = 6;
    end;
if state = 5 then   value = value + (-0.180952385);
if state = 6 then
    do;
        if petal_length < 4.75 or missing(petal_length) then state = 7;
        else state = 8;
    end;
if state = 7 then   value = value + (0.142857149);
if state = 8 then   value = value + (-0.161290333);

Y_Pred1 = 1/(1+exp(-value));
Y_Pred0 = 1 - Y_pred1;

Below is the dump file output for 1st tree:

booster[0]:
    0:[sepal_width<2.95000005] yes=1,no=2,missing=1
        1:[petal_length<4.75] yes=3,no=4,missing=3
            3:leaf=0.1586207
            4:leaf=-0.127272725
        2:[petal_length<3] yes=5,no=6,missing=5
            5:leaf=-0.180952385
            6:[petal_length<4.75] yes=7,no=8,missing=7
                7:leaf=0.142857149
                8:leaf=-0.161290333

So basically, what I am trying to do is that, save the node number in the variable "state" and accordingly access the leaf nodes (which I learned from the article by Shiutang-Li mentioned in the above links).

Here is the issue I am facing:

For up to approx 40 trees, the predicted score is exactly matching. For example please see below:

Case 1:

Predicted value using python for 10 trees:

Y_pred1 = Base_Model.predict(xgdmat)

print("Development- Y_Actual: ",np.mean(Y)," Y predicted: ",np.mean(Y_pred1))

Output:

Average- Y_Actual:  0.3333333333333333  Average Y predicted:  0.4021197

Predicted value using SAS for 10 trees:

Average Y predicted:  0.4021197

Case 2:

Predicted value using python for 100 trees:

Y_pred1 = Base_Model.predict(xgdmat)

print("Development- Y_Actual: ",np.mean(Y)," Y predicted: ",np.mean(Y_pred1))

Output:

Average- Y_Actual:  0.3333333333333333  Average Y predicted:  0.33232176

Predicted value using SAS for 100 trees:

Average Y predicted:  0.3323159

As you can see the scores are not exactly matching(matches up to 4 decimal points) for 100 trees. Also, I have tried this on large files where the difference in scores is quite high i.e more than 10% deviation in scores.

Could anyone point me towards any error in my code, so that the scores match exactly? Below are some of my queries:

1) Is my score calculation correct?

2) I found something related to gamma (the regularization term). Does it impact the way xgboost calculates the scores using leaf values?

3) Do the leaf values given by the dump files have any rounding off, thus creating this issue?

Also, I would appreciate any other method to do this task apart from parsing the dump file.

P.S.: I only have SAS EG and do not have access to SAS EM or SAS IML.

Ved
  • 93
  • 2
  • 4

5 Answers5

1

I had similar experience with getting matching scores.
What I understand is that the scoring may stop early unless you fix the ntree_limit option to match the n_estimators that you used during model fitting.

df['score']= xgclfpkl.predict(df[xg_features], ntree_limit=500)

After I started using ntree_limit, I started getting matching scores.

Markus G.
  • 1,620
  • 2
  • 25
  • 49
KKane
  • 11
  • 1
  • Hi KKane, Thanks a lot for the comment as I am now stuck on this. However, I didn't get what you said. Did you mean that XGBoost automatically stops the number of trees early? Could you please help me understand this. Also it seems that you have solved this issue. Could you please help me by posting the code here. Thanks a lot for the help. – Ved Apr 19 '19 at 10:09
  • Hi KKane, Could you please reply to my above query – Ved Apr 24 '19 at 12:04
0

I have a similar experience that requires to extract xgboost scoring code from R to SAS.

Initially, I faced the same issue as you have here: with smaller trees, there's not much difference between the scores in R and SAS, but once the number of trees goes up to 100 or beyond, I began to observe discrepancies.

I did 3 things to narrow the discrepancies:

  1. Make sure the missing group went in the right direction, you would need to be explicit. Otherwise SAS would treat the missing value as the smallest value than all numbers. The rule should be something like below in SAS.

if sepal_width > 2.95000005 or missing(sepal_width) then state = 1;else state = 2;
or
if sepal_width <= 2.95000005 and ~missing(sepal_width) then state = 1;else state = 2;

  1. I used an R package called float to make the score has more decimal places. as.numeric(float::fl(Quality))

  2. Make sure the SAS data were in the same shape as the data you trained in Python.

Hope the above helps.

DDZR
  • 1
  • 1
  • Hi DDZR, Thank you for your reply. I have already taken care of point 1 and point 3. However I did not understand point 2. I am reading the floating point value as it is from the dump file; how can I increase the decimal points here? – Ved May 25 '20 at 10:26
0

I had a little look into incorporating this into my own code.

I found there was a small issue around missing handling.

It seems to work fine where you have a logic like

if petal_length < 3 or missing(petal_length) then state = 5;
        else state = 6;

but say the missing group should go to state 6 instead of state 5. Then you get code like this:

if petal_length < 3 then state = 5;
        else state = 6;

What state does petal_length = missing (.) get in this instance? Well here it still goes to state 5 (rather than the intended state 6) as in SAS missing is classed as less than any number.

To fix this you could assign all missing values to 999999999999999 (picking a high number as the XGBoost format always uses less than (<)) and then replace

missing_value_handling = (" or missing(" + out[1] + ")")

with

missing_value_handling = (" or " + out[1] + "=999999999999999 ")

in your string_parser.

David Buck
  • 3,752
  • 35
  • 31
  • 35
  • Thanks a lot david for the suggestion. I have made these changes. However, in my data there is no missing observations and the issue remains the same i.e. the scores are still not matching after 4th decimal point. The code is reading the dump file correctly, therefore it can only happen if there is a decimal point error in the dump file itself. However, I certainly do not believe that this can be the case and therefore not sure what to do. Could you please suggest what might be causing this. Thanks a lot for your help – Ved May 25 '20 at 08:12
0

Couple of points-

Firstly, the regular expression to match the leaf return value does not capture the "e-decimals" scientific notation (default) in the dump. Explicit example (second one is the correct modification!)-

s = '3:leaf=9.95066429e-09'
out = re.findall(r"[\d.-]+", s)
out2 = re.findall(r"-?[\d.]+(?:e-?\d+)?", s)
out2,out

(Easy to fix but not to spot as exactly one leaf was affected in my model!)

Secondly, the question is on binary but in multi-class targets there are separate trees for each class in the dump, so you have T*C trees total, where T is the number of boost rounds and C is the number of classes. For class c (in {0,1,...,C-1}) you need to evaluate (and sum terminal leaves for) trees i*C +c for i = 0,...,T-1. Then softmax it to match the predictions from xgb.

P.Windridge
  • 246
  • 2
  • 11
  • Thanks a lot P.Windridge for the correction regarding reading of scientific notations from the dump file. I have made the required changes. However, in the example stated above there was no error related scientific notations and issue remains the same. Could you please let me know how to solve this. PS: I do no have any missing value in the data. – Ved May 25 '20 at 08:16
0

Below is the code fragment which prints all the rules extracted from the booster trees from xgboost model. The below code assumes that you already have a model packaged into a pickle file.

import pandas as pd
import numpy as np
import pickle
import networkx as nx

# NOTE(review): MODEL_FILE is not defined in this snippet — the caller must
# set it to the path of a pickled sklearn-wrapper XGBoost model.
_model = pickle.load(open(MODEL_FILE, "rb"))

# Flatten every booster tree into one row-per-node DataFrame; columns used
# below are Tree, ID, Feature, Split, Yes, No, Missing and Gain.
df = _model._Booster.trees_to_dataframe()
# '_missing' == 'Yes' when the missing branch coincides with the 'Yes'
# (less-than) branch of the split; leaf rows (null Yes/No) get 'No'.
df['_missing'] = df.apply(
    lambda x: 'Yes' if pd.notnull(x['Missing']) and pd.notnull(x['Yes']) and pd.notnull(x['No']) and x['Missing'] == x[
        'Yes'] else 'No', axis = 1)

# Build a directed graph with one vertex per tree node.
G = nx.DiGraph()
G.add_nodes_from(df.ID.tolist())

# Edges toward the 'Yes' child carry the less-than condition; the missing
# clause is appended when missing follows this branch.
yes_edges = df[['ID', 'Yes', 'Feature', 'Split', '_missing']].dropna()
yes_edges['label'] = yes_edges.apply(
    lambda x: "({feature} < {value:.4f} or {feature} is null)".format(feature = x['Feature'], value = x['Split']) if x['_missing'] == 'Yes'
    else "({feature} < {value:.4f})".format(feature = x['Feature'], value = x['Split']),
    axis = 1
)

# Edges toward the 'No' child carry the complementary >= condition.
no_edges = df[['ID', 'No', 'Feature', 'Split', '_missing']].dropna()
no_edges['label'] = no_edges.apply(
    lambda x: "({feature} >= {value:.4f} or {feature} is null)".format(feature = x['Feature'], value = x['Split']) if x['_missing'] == 'No'
    else "({feature} >= {value:.4f})".format(feature = x['Feature'], value = x['Split']),
    axis = 1
)

# Attach the readable condition to each edge as the 'expr' attribute
# (v[5] is the 'label' column built above).
for v in yes_edges.values:
    G.add_edge(v[0], v[1], feature = v[2], expr = v[5])

for v in no_edges.values:
    G.add_edge(v[0], v[1], feature = v[2], expr = v[5])

# Leaf rows store their weight in the 'Gain' column of trees_to_dataframe.
leaf_node_score_values = {i[0]: i[1] for i in df[df.Feature == 'Leaf'][['ID', 'Gain']].values}
# Map each node ID back to the index of the tree it belongs to.
nodeID_to_tree_map = {i[1]: i[0] for i in df[['Tree', 'ID']].values}

# Classify vertices by degree: roots have no incoming edges, leaves no
# outgoing edges.
roots = []
leaves = []
for node in G.nodes:
    if G.in_degree(node) == 0:  # it's a root
        roots.append(node)
    elif G.out_degree(node) == 0:  # it's a leaf
        leaves.append(node)

# Enumerate every root-to-leaf path; since each tree is a separate weakly
# connected component, only paths within one tree are found.
paths = []
for root in roots:
    for leaf in leaves:
        for path in nx.all_simple_paths(G, root, leaf):
            paths.append(path)

# Join the edge conditions along each path into one "a and b and c" rule,
# and record (root node, tree index, rule text, leaf score).
rules = []
temp = []
for path in paths:
    parts = []
    for i in range(len(path) - 1):
        parts.append(G[path[i]][path[i + 1]]['expr'])
    rules.append(" and ".join(parts))
    temp.append((
        path[0],
        nodeID_to_tree_map.get(path[0]),
        " and ".join(parts),
        leaf_node_score_values.get(path[-1])
    ))

rules_df = pd.DataFrame.from_records(temp, columns = ['node', 'tree', 'rule', 'score'])
# NOTE(review): this applies the logistic transform to a single leaf's
# score, not to the summed score across all trees — it is a per-rule
# value, not the model's predicted probability.
rules_df['prob'] = rules_df.apply(lambda x: 1 / (1 + np.exp(-1 * x['score'])), axis = 1)
rules_df['rule_idx'] = rules_df.index
rules_df = rules_df.drop(['node'], axis = 1)

print("n_rules -> {}".format(len(rules_df)))

# Free the intermediate structures; only rules_df is kept.
del G, df, roots, leaves, yes_edges, no_edges, temp, rules

The above code prints every rule in the format as below:

if x>y and a>b and c<d then e
Rakesh Chintha
  • 615
  • 5
  • 6