XAI

The xai module contains methods that help explain a model’s decisions.

For this module to work properly, Graphviz must be installed. On Debian-based Linux distributions (such as Ubuntu), use:

sudo apt-get install graphviz

Or using conda:

conda install graphviz

For more information see here.

Draw Tree

Deprecated since version 1.6.4: Use sklearn.tree.plot_tree instead

xai.draw_tree(tree: sklearn.tree._classes.BaseDecisionTree, feature_names: Optional[List[str]] = None, class_names: Optional[List[str]] = None, *, ax: Optional[matplotlib.axes._axes.Axes] = None, **kwargs) → matplotlib.axes._axes.Axes

Receives a decision tree and returns a plot of the tree for easy interpretation.

Parameters:
  • tree – decision tree.
  • feature_names – the features names.
  • class_names – the classes names or labels.
  • ax – Axes object to draw the plot onto, otherwise uses the current Axes.
  • kwargs – other keyword arguments. All other keyword arguments are passed to matplotlib.axes.Axes.pcolormesh().

Returns:

Returns the Axes object with the plot drawn onto it.

Code Example

The following examples use the iris dataset from scikit-learn, so first let’s load it:

from sklearn import datasets


iris = datasets.load_iris()
x = iris.data
y = iris.target

We’ll create a simple decision tree classifier and plot it:

from matplotlib import pyplot
from sklearn.tree import DecisionTreeClassifier

from ds_utils.xai import draw_tree


# Create decision tree classifier object
clf = DecisionTreeClassifier(random_state=0)

# Train model
clf.fit(x, y)

draw_tree(clf, iris.feature_names, iris.target_names)
pyplot.show()

And the following image will be shown:

Decision Tree Visualization
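
Because draw_tree is deprecated in favour of sklearn.tree.plot_tree, the same kind of plot can be produced with scikit-learn directly. The following is a minimal sketch rather than a drop-in replacement: the figure size, max_depth=3 and filled=True are choices made here for readability, not settings taken from ds_utils.

```python
from matplotlib import pyplot
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = datasets.load_iris()

# Train a shallow tree so the plot stays readable (max_depth is an arbitrary choice)
clf = DecisionTreeClassifier(random_state=0, max_depth=3)
clf.fit(iris.data, iris.target)

# plot_tree draws onto the given Axes and returns the list of node annotations
fig, ax = pyplot.subplots(figsize=(10, 6))
annotations = plot_tree(
    clf,
    feature_names=list(iris.feature_names),
    class_names=iris.target_names.tolist(),
    filled=True,
    ax=ax,
)
```

Call pyplot.show() afterwards to display the figure, as in the example above.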

Draw Dot Data

xai.draw_dot_data(dot_data: str, *, ax: Optional[matplotlib.axes._axes.Axes] = None, **kwargs) → matplotlib.axes._axes.Axes

Receives a Graphviz DOT language string and returns a plot of the graph it describes.

Parameters:
  • dot_data – Graphviz DOT language string.
  • ax – Axes object to draw the plot onto, otherwise uses the current Axes.
  • kwargs – other keyword arguments. All other keyword arguments are passed to matplotlib.axes.Axes.pcolormesh().

Returns:

Returns the Axes object with the plot drawn onto it.

Code Example

We’ll create a simple diagram and plot it:

from matplotlib import pyplot

from ds_utils.xai import draw_dot_data


dot_data = "digraph D{\n" \
           "\tA [shape=diamond]\n" \
           "\tB [shape=box]\n" \
           "\tC [shape=circle]\n" \
           "\n" \
           "\tA -> B [style=dashed, color=grey]\n" \
           "\tA -> C [color=\"black:invis:black\"]\n" \
           "\tA -> D [penwidth=5, arrowhead=none]\n" \
           "\n" \
           "}"

draw_dot_data(dot_data)
pyplot.show()

And the following image will be shown:

Diagram Visualization
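
A DOT string like the one above can also be assembled programmatically instead of being written by hand. The sketch below uses only the standard library; build_dot is a hypothetical helper introduced for illustration and is not part of ds_utils.

```python
from typing import Dict, Iterable, Tuple


def build_dot(name: str, nodes: Dict[str, str], edges: Iterable[Tuple[str, str]]) -> str:
    """Build a DOT digraph string from node shapes and directed edges (hypothetical helper)."""
    lines = [f"digraph {name} {{"]
    # One line per node, carrying its shape attribute
    for node, shape in nodes.items():
        lines.append(f"\t{node} [shape={shape}]")
    # One line per directed edge
    for source, target in edges:
        lines.append(f"\t{source} -> {target}")
    lines.append("}")
    return "\n".join(lines)


dot_data = build_dot(
    "D",
    {"A": "diamond", "B": "box", "C": "circle"},
    [("A", "B"), ("A", "C")],
)
print(dot_data)
```

The resulting string can then be passed to draw_dot_data(dot_data) in the same way as the hand-written one.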

Generate Decision Paths

xai.generate_decision_paths(classifier: sklearn.tree._classes.BaseDecisionTree, feature_names: Optional[List[str]] = None, class_names: Optional[List[str]] = None, tree_name: Optional[str] = None, indent_char: str = '\t') → str

Receives a decision tree and returns the underlying decision rules (or ‘decision paths’) as text in Python syntax. Original code

Parameters:
  • classifier – decision tree.
  • feature_names – the features names.
  • class_names – the classes names or labels.
  • tree_name – the name of the tree (function signature).
  • indent_char – the character used for indentation.

Returns:

Textual representation of the decision paths of the tree.

Code Example

We’ll create a simple decision tree classifier, reusing the iris data loaded above, and print its decision paths:

from sklearn.tree import DecisionTreeClassifier

from ds_utils.xai import generate_decision_paths


# Create decision tree classifier object
clf = DecisionTreeClassifier(random_state=0, max_depth=3)

# Train model
clf.fit(x, y)
print(generate_decision_paths(clf, iris.feature_names, iris.target_names.tolist(),
                              "iris_tree"))

The following text will be printed:

def iris_tree(petal width (cm), petal length (cm)):
    if petal width (cm) <= 0.8000:
        # return class setosa with probability 0.9804
        return ("setosa", 0.9804)
    else:  # if petal width (cm) > 0.8000
        if petal width (cm) <= 1.7500:
            if petal length (cm) <= 4.9500:
                # return class versicolor with probability 0.9792
                return ("versicolor", 0.9792)
            else:  # if petal length (cm) > 4.9500
                # return class virginica with probability 0.6667
                return ("virginica", 0.6667)
        else:  # if petal width (cm) > 1.7500
            if petal length (cm) <= 4.8500:
                # return class virginica with probability 0.6667
                return ("virginica", 0.6667)
            else:  # if petal length (cm) > 4.8500
                # return class virginica with probability 0.9773
                return ("virginica", 0.9773)
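
As the output above shows, feature names appear in the generated text as given, so names such as petal width (cm) produce text that reads like Python but would not compile as-is. If runnable output is wanted, one option is to convert the names into valid identifiers before passing them as feature_names. This is a sketch under that assumption; to_identifier is a hypothetical helper, not part of ds_utils, and it does not handle Python keywords.

```python
import re
from typing import List


def to_identifier(name: str) -> str:
    """Turn an arbitrary feature name into a valid Python identifier (hypothetical helper)."""
    # Replace every run of non-word characters with a single underscore
    cleaned = re.sub(r"\W+", "_", name).strip("_")
    # Identifiers cannot be empty or start with a digit, so prefix those cases
    if not cleaned or cleaned[0].isdigit():
        cleaned = "f_" + cleaned
    return cleaned


feature_names: List[str] = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal width (cm)",
]
safe_names = [to_identifier(name) for name in feature_names]
print(safe_names)
```

The cleaned list can be passed in place of iris.feature_names in the call to generate_decision_paths.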

Plot Features Importance

xai.plot_features_importance(feature_names: List[str], feature_importance: List[float], *, ax: Optional[matplotlib.axes._axes.Axes] = None, **kwargs) → matplotlib.axes._axes.Axes

Plots feature importance as a bar chart.

Parameters:
  • feature_names – list of feature names (strings).
  • feature_importance – list of feature importance values (floats), one per feature name.
  • ax – Axes object to draw the plot onto, otherwise uses the current Axes.
  • kwargs – other keyword arguments. All other keyword arguments are passed to matplotlib.axes.Axes.pcolormesh().

Returns:

Returns the Axes object with the plot drawn onto it.

Code Example

For this example I created a dummy dataset. You can find the data in the resources directory of the package’s tests folder.

Let’s see how to use the code:

import pandas

from matplotlib import pyplot
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

from ds_utils.xai import plot_features_importance


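# Replace the placeholder with the path to the dataset CSV file
data_1M = pandas.read_csv("path/to/dataset")
target = data_1M["x12"]
categorical_features = ["x7", "x10"]
# One-hot encode each categorical feature into new "<feature>_<category>" columns
for feature in categorical_features:
    # scikit-learn >= 1.2 uses sparse_output; older versions use sparse=False
    enc = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
    enc_out = enc.fit_transform(data_1M[[feature]])
    for j, category in enumerate(enc.categories_[0]):
        data_1M[feature + "_" + category] = enc_out[:, j]
# Keep every column except the target and the original categorical features
features = data_1M.columns.to_list()
features.remove("x12")
features.remove("x7")
features.remove("x10")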

clf = DecisionTreeClassifier(random_state=42)
clf.fit(data_1M[features], target)
plot_features_importance(features, clf.feature_importances_)

pyplot.show()

And the following image will be shown:

Plot Features Importance
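
With many one-hot encoded columns the chart can become crowded. One way to keep it readable is to rank the features by importance and plot only the top few. The sketch below shows that ranking step with the standard library; top_features is a hypothetical helper, and the names and values are made up for illustration rather than taken from the dataset above.

```python
from typing import List, Sequence, Tuple


def top_features(names: Sequence[str], importance: Sequence[float], k: int) -> Tuple[List[str], List[float]]:
    """Return the k most important features, largest first (hypothetical helper)."""
    # Pair each name with its importance, sort descending, and keep the first k pairs
    ranked = sorted(zip(names, importance), key=lambda pair: pair[1], reverse=True)[:k]
    top_names = [name for name, _ in ranked]
    top_values = [value for _, value in ranked]
    return top_names, top_values


# Made-up importance values, for illustration only
names = ["x1", "x2", "x7_a", "x7_b"]
importance = [0.10, 0.55, 0.05, 0.30]
top_names, top_values = top_features(names, importance, k=2)
print(top_names, top_values)
```

The two returned lists can be passed to plot_features_importance(top_names, top_values) in place of the full lists.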