XAI (Explainable AI)

The xai module contains methods that help explain model decisions.

For this module to work properly, Graphviz must be installed. Use the following commands based on your operating system:

For Debian- or Ubuntu-based Linux systems:

sudo apt-get install graphviz

For Windows:

choco install graphviz

For macOS:

brew install graphviz

Using conda:

conda install graphviz

For more information, see the Graphviz download page.
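After installing, it can help to confirm that Graphviz is reachable before calling the plotting functions. The sketch below assumes the module relies on the Graphviz `dot` executable being on the PATH; the helper name `graphviz_available` is hypothetical and not part of ds_utils.

```python
import shutil


def graphviz_available() -> bool:
    """Return True if the Graphviz `dot` executable is found on the PATH."""
    # Assumption: the xai module needs the `dot` command-line tool.
    return shutil.which("dot") is not None


if not graphviz_available():
    print("Graphviz 'dot' executable not found; install Graphviz first.")
```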

Draw Tree

Deprecated since version 1.6.4: Use sklearn.tree.plot_tree instead.

The draw_tree function visualizes a decision tree classifier, making it easier to understand the tree’s structure and decision-making process. This can be particularly useful for model interpretation and debugging.

ds_utils.xai.draw_tree(tree: BaseDecisionTree, feature_names: List[str] | None = None, class_names: List[str] | None = None, *, ax: Axes | None = None, **kwargs) -> Axes

Plot a graph of the decision tree for easy interpretation.

Parameters:
  • tree – Decision tree.

  • feature_names – List of feature names.

  • class_names – List of class names or labels.

  • ax – Axes object to draw the plot onto, otherwise uses the current Axes.

  • kwargs – Additional keyword arguments passed to matplotlib.axes.Axes.imshow().

Returns:

Axes object with the plot drawn onto it.

Code Example

In the following example, we’ll use the iris dataset from scikit-learn:

from sklearn import datasets

iris = datasets.load_iris()
X = iris.data  # Feature matrix, one row per sample
y = iris.target

Now, let’s create a simple decision tree classifier and plot it:

import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from ds_utils.xai import draw_tree

# Create decision tree classifier object
clf = DecisionTreeClassifier(random_state=0)

# Train model
clf.fit(X, y)

# Draw the tree
draw_tree(clf, iris.feature_names, iris.target_names)
plt.show()

The following image will be displayed:

Decision Tree Visualization
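Because draw_tree is deprecated in favor of sklearn.tree.plot_tree, a rough equivalent of the example above using scikit-learn directly might look like the following. This is a minimal sketch; the figure size and the filled=True option are illustrative choices, not settings taken from ds_utils.

```python
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = datasets.load_iris()

# Train the same classifier as in the draw_tree example
clf = DecisionTreeClassifier(random_state=0)
clf.fit(iris.data, iris.target)

# plot_tree draws onto a matplotlib Axes and returns the annotation artists
fig, ax = plt.subplots(figsize=(12, 8))
annotations = plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,  # color each node by its majority class
    ax=ax,
)
# Call plt.show() to display the figure in an interactive session.
```

Unlike draw_tree, plot_tree renders with matplotlib only, so this variant does not need Graphviz.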

Draw Dot Data

The draw_dot_data function visualizes graph structures defined in DOT language. This is useful for creating custom graph visualizations, including decision trees, flowcharts, or any other graph-based representations.

ds_utils.xai.draw_dot_data(dot_data: str, *, ax: Axes | None = None, **kwargs) -> Axes

Plot a graph from Graphviz’s Dot language string.

Parameters:
  • dot_data – Graphviz’s Dot language string. Use sklearn.tree.export_graphviz to generate the dot data string.

  • ax – Axes object to draw the plot onto, otherwise uses the current Axes.

  • kwargs – Additional keyword arguments passed to matplotlib.axes.Axes.imshow().

Returns:

Axes object with the plot drawn onto it.

Raises:

ValueError – If the dot_data is empty or invalid.

Code Example

Let’s create a simple diagram and plot it:

import matplotlib.pyplot as plt
from ds_utils.xai import draw_dot_data

dot_data = """
digraph D {
    A [shape=diamond]
    B [shape=box]
    C [shape=circle]

    A -> B [style=dashed, color=grey]
    A -> C [color="black:invis:black"]
    A -> D [penwidth=5, arrowhead=none]
}
"""

draw_dot_data(dot_data)
plt.show()

The following image will be displayed:

Diagram Visualization
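The dot_data parameter description mentions sklearn.tree.export_graphviz as a source of DOT strings. The sketch below shows one way to combine the two for a fitted tree. The helper `show_tree` is a hypothetical wrapper, and the max_depth=2 and filled=True settings are only illustrative.

```python
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = datasets.load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the DOT source as a string
dot_data = export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
)


def show_tree(dot_source):
    """Render DOT source with draw_dot_data (needs ds_utils and Graphviz)."""
    import matplotlib.pyplot as plt
    from ds_utils.xai import draw_dot_data

    ax = draw_dot_data(dot_source)
    plt.show()
    return ax
```

Calling show_tree(dot_data) should then display the exported tree, provided ds_utils and Graphviz are installed.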

Generate Decision Paths

Deprecated since version 1.8.0: Use sklearn.tree.export_text instead.

ds_utils.xai.generate_decision_paths(classifier: BaseDecisionTree, feature_names: List[str] | None = None, class_names: List[str] | None = None, tree_name: str | None = None, indent_char: str = '\t') -> str
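Since the deprecation note points to sklearn.tree.export_text, here is a minimal sketch of that replacement on the iris dataset. The output format is scikit-learn's own text report and is not assumed to match what generate_decision_paths produced; max_depth=2 is chosen only to keep the output short.

```python
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, export_text

iris = datasets.load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# export_text returns the decision rules as an indented plain-text report
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```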

Plot Feature Importance

The plot_features_importance function visualizes the importance of each feature in a machine learning model. This shows which features contribute most to the model’s predictions, which helps with feature selection and model interpretation.

ds_utils.xai.plot_features_importance(feature_names: ndarray | List[str], feature_importance: ndarray | List[float], *, ax: Axes | None = None, **kwargs) -> Axes

Plot feature importance as a bar chart.

Parameters:
  • feature_names – Numpy array or list of feature names with shape (n_features,).

  • feature_importance – Numpy array or list of feature importance values with shape (n_features,).

  • ax – Axes object to draw the plot onto, otherwise uses the current Axes.

  • kwargs – Additional keyword arguments passed to matplotlib.axes.Axes.bar().

Returns:

Axes object with the plot drawn onto it.

Raises:

ValueError – If feature_names and feature_importance have different lengths, or if the input is otherwise invalid.

Code Example

For this example, we’ll use a dummy dataset. You can find the data in the resources directory of the package’s tests folder.

Here’s how to use the code:

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier
from ds_utils.xai import plot_features_importance

# Load the dataset
data_1M = pd.read_csv('path/to/dataset.csv')
target = data_1M["x12"]
categorical_features = ["x7", "x10"]

# Perform one-hot encoding for categorical features
for feature in categorical_features:
    # scikit-learn >= 1.2 uses sparse_output; older versions use sparse=False
    enc = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
    enc_out = enc.fit_transform(data_1M[[feature]])
    for i, category in enumerate(enc.categories_[0]):
        data_1M[f"{feature}_{category}"] = enc_out[:, i]

# Prepare feature list
features = [col for col in data_1M.columns if col not in ["x12", "x7", "x10"]]

# Create and train the classifier
clf = DecisionTreeClassifier(random_state=42)
clf.fit(data_1M[features], target)

# Plot feature importance
plot_features_importance(features, clf.feature_importances_)
plt.show()

In this example:

  • x12 is the target variable we’re trying to predict.

  • x7 and x10 are categorical features that we one-hot encode.

  • The remaining columns (x1, x2, x3, etc.) are numerical features.

  • After one-hot encoding, we create a list of all features, excluding the original categorical columns and the target variable.

  • We then train a decision tree classifier and plot the importance of each feature.

The following image will be displayed:

Plot Feature Importance
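For readers without access to the test dataset, a self-contained variant on the iris dataset might look like the sketch below. It also uses the ax parameter and forwards a keyword argument, assuming extra kwargs reach matplotlib.axes.Axes.bar() as documented above. The helper `plot_on_axes`, the figure size, and the color are hypothetical illustrative choices.

```python
import numpy as np
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=0)
clf.fit(iris.data, iris.target)

# Both inputs must have the same length, one entry per feature
feature_names = np.asarray(iris.feature_names)
feature_importance = clf.feature_importances_


def plot_on_axes(names, importance):
    """Draw the importance chart on an explicit Axes (needs ds_utils)."""
    import matplotlib.pyplot as plt
    from ds_utils.xai import plot_features_importance

    fig, ax = plt.subplots(figsize=(8, 4))
    # Assumption: color is forwarded to Axes.bar via **kwargs
    plot_features_importance(names, importance, ax=ax, color="tab:green")
    ax.set_ylabel("Importance")
    return ax
```

Calling plot_on_axes(feature_names, feature_importance) followed by plt.show() should display the chart when ds_utils is installed.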