"""Data preprocessing utilities."""
from typing import Callable, List, Optional, Union
import warnings
from matplotlib import axes, dates, pyplot as plt, ticker
import numpy as np
from numpy.random import RandomState
import pandas as pd
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
import seaborn as sns
from sklearn.base import TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import mutual_info_classif
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder
from ds_utils.math_utils import safe_percentile
def _plot_clean_violin_distribution(
    series: pd.Series, include_outliers: bool, outlier_iqr_multiplier: float, ax: Optional[axes.Axes] = None, **kwargs
) -> axes.Axes:
    """Plot a violin distribution for a numeric series with optional outlier trimming.

    When ``include_outliers`` is False, values outside the IQR fence are removed
    before plotting. The fence is defined as
    [Q1 - k * IQR, Q3 + k * IQR], where ``k`` is ``outlier_iqr_multiplier``, and
    the bounds are clipped to the observed min/max of the series.

    :param series: Numeric series to visualize. NA values are not dropped explicitly here;
                   when trimming is active they fail the fence comparison and are filtered out.
    :param include_outliers: Whether to include values outside the IQR fence.
    :param outlier_iqr_multiplier: Multiplier ``k`` used to compute the IQR fence.
    :param ax: Matplotlib Axes to draw on. If None, the Axes chosen by
               ``seaborn.violinplot`` (the currently active Axes) is used.
    :param kwargs: Additional keyword arguments passed to ``seaborn.violinplot``.
    :return: The Axes object with the violin plot.
    """
    if include_outliers:
        series_plot = series.copy()
    else:
        q1 = series.quantile(0.25)
        q3 = series.quantile(0.75)
        min_series_value = series.min()
        max_series_value = series.max()
        iqr = q3 - q1
        lower_bound = max(min_series_value, q1 - outlier_iqr_multiplier * iqr)
        upper_bound = min(max_series_value, q3 + outlier_iqr_multiplier * iqr)
        series_plot = series[(series >= lower_bound) & (series <= upper_bound)].copy()

    # Keep the Axes seaborn hands back: it is the same object when ``ax`` was supplied,
    # and a usable Axes (instead of None) when the caller relied on the default.
    ax = sns.violinplot(y=series_plot, hue=None, legend=False, ax=ax, **kwargs)
    # A single violin has no meaningful x positions, so hide the x ticks.
    ax.set_xticks([])
    ax.set_ylabel("Values")
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    return ax
[docs]
def visualize_feature(
    series: pd.Series,
    remove_na: bool = False,
    *,
    include_outliers: bool = True,
    outlier_iqr_multiplier: float = 1.5,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Draw a single-feature plot whose kind is selected from the series dtype.

    Dispatch by dtype:

    - Float: violin distribution. With ``include_outliers=False`` the values outside
      the IQR fence [Q1 - k*IQR, Q3 + k*IQR] (``k=outlier_iqr_multiplier``) are
      dropped before plotting.
    - Datetime: line plot of the value counts, ordered by timestamp.
    - Anything else (object, categorical, bool, int, ...): count plot. Series with
      more than 10 distinct values keep their 10 most frequent values and the rest
      is grouped under ``"Other values"``.

    If the Axes has no title yet, it is set to ``"<series name> (<dtype>)"``. The
    x-axis label is always cleared.

    :param series: The data series to visualize.
    :param remove_na: If True, NA values are dropped before plotting; otherwise the
                      series is used as given.
    :param include_outliers: Whether to include outliers for float features.
    :param outlier_iqr_multiplier: IQR multiplier used to trim outliers for float features.
    :param ax: Axes in which to draw the plot. If None, a new figure and Axes are created.
    :param kwargs: Extra keyword arguments forwarded to the underlying plotting function
                   (``seaborn.violinplot``, ``Series.plot``, or ``seaborn.countplot``).
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        ax = plt.subplots()[1]

    feature_series = series if not remove_na else series.dropna()
    is_float = pd.api.types.is_float_dtype(feature_series)
    is_datetime = pd.api.types.is_datetime64_any_dtype(feature_series)

    tick_labels = None
    if is_float:
        ax = _plot_clean_violin_distribution(feature_series, include_outliers, outlier_iqr_multiplier, ax, **kwargs)
    elif is_datetime:
        counts_by_timestamp = feature_series.value_counts().sort_index()
        counts_by_timestamp.plot(kind="line", ax=ax, **kwargs)
        tick_labels = ax.get_xticks()
    else:
        sns.countplot(x=_copy_series_or_keep_top_10(feature_series), ax=ax, **kwargs)
        tick_labels = ax.get_xticklabels()

    if not ax.get_title():
        ax.set_title(f"{feature_series.name} ({feature_series.dtype})")
    ax.set_xlabel("")

    # The violin plot hides its x ticks, so only the other plot kinds are relabeled.
    if not is_float:
        current_ticks = ax.get_xticks()
        ax.xaxis.set_major_locator(ticker.FixedLocator(current_ticks))
        ax.set_xticklabels(tick_labels, rotation=45, ha="right")

    if is_datetime:
        # Render the numeric tick positions as formatted timestamps.
        ax.xaxis.set_major_formatter(_convert_numbers_to_dates)
    return ax
[docs]
def visualize_correlations(correlation_matrix: pd.DataFrame, *, ax: Optional[axes.Axes] = None, **kwargs) -> axes.Axes:
    """Draw an annotated heatmap of the lower triangle of a correlation matrix.

    The matrix is taken as given (it is not computed here). The diagonal and the
    upper triangle are masked out, and every visible cell is annotated with its
    value formatted to three decimals.

    `Original Seaborn code <https://seaborn.pydata.org/examples/many_pairwise_correlations.html>`_.

    :param correlation_matrix: The correlation matrix.
    :param ax: Axes in which to draw the plot. If None, a new figure and Axes are created.
    :param kwargs: Additional keyword arguments passed to seaborn's heatmap function.
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        ax = plt.subplots()[1]

    # True marks a hidden cell: the diagonal and everything above it.
    hidden_cells = np.triu(np.ones_like(correlation_matrix, dtype=bool))
    sns.heatmap(correlation_matrix, mask=hidden_cells, annot=True, fmt=".3f", ax=ax, **kwargs)
    return ax
[docs]
def plot_correlation_dendrogram(
    correlation_matrix: pd.DataFrame,
    cluster_distance_method: Union[str, Callable] = "average",
    *,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Plot a dendrogram of the correlation matrix, showing hierarchically the most correlated variables.

    The distance between two variables is ``1 - correlation``; the condensed form of
    that distance matrix is clustered with ``scipy.cluster.hierarchy.linkage`` and
    drawn with the leaves on the left-oriented axis, labeled by column name.

    `Original XAI code <https://github.com/EthicalML/XAI>`_.

    :param correlation_matrix: The correlation matrix.
    :param cluster_distance_method: Method for calculating the distance between newly formed clusters.
        `Read more here <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html>`_
    :param ax: Axes in which to draw the plot. If None, a new figure and Axes are created.
    :param kwargs: Additional keyword arguments applied to the Axes through ``ax.set``
                   (they are not forwarded to ``dendrogram``).
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        ax = plt.subplots()[1]

    distance_matrix = 1 - correlation_matrix
    linkage_matrix = linkage(squareform(distance_matrix), method=cluster_distance_method)
    leaf_labels = correlation_matrix.columns.tolist()

    ax.set(**kwargs)
    dendrogram(linkage_matrix, labels=leaf_labels, orientation="left", ax=ax)
    return ax
[docs]
def plot_features_interaction(
    data: pd.DataFrame,
    feature_1: str,
    feature_2: str,
    *,
    include_outliers: bool = True,
    outlier_iqr_multiplier: float = 1.5,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Plot the joint distribution of two columns, choosing the plot from their dtypes.

    "Categorical-like" means categorical, boolean or object dtype. The plot is chosen
    as follows:

    - Both categorical-like: grouped histograms.
    - One categorical-like and one datetime: violin plot with time on the x-axis.
    - One categorical-like and one numeric: violin plot of the numeric feature per category.
    - One datetime and one numeric: line plot with the datetime feature on the x-axis.
    - Otherwise (both numeric): scatter plot.

    For the categorical-vs-numeric case, outliers of the numeric feature can be
    trimmed with an IQR fence [Q1 - k*IQR, Q3 + k*IQR], where ``k`` is
    ``outlier_iqr_multiplier``.

    :param data: The input DataFrame where each feature is a column.
    :param feature_1: Name of the first feature.
    :param feature_2: Name of the second feature.
    :param include_outliers: Whether to include values outside the IQR fence for
                             categorical-vs-numeric violin plots (default True).
    :param outlier_iqr_multiplier: Multiplier ``k`` for the IQR fence when trimming
                                   outliers in categorical-vs-numeric plots (default 1.5).
    :param ax: Axes in which to draw the plot. If None, a new figure and Axes are created.
    :param kwargs: Additional keyword arguments forwarded to the underlying plotting
                   function (``seaborn.violinplot``, ``Axes.hist``, ``Axes.scatter`` or ``Axes.plot``).
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        ax = plt.subplots()[1]

    first_dtype = data[feature_1].dtype
    second_dtype = data[feature_2].dtype
    is_datetime = pd.api.types.is_datetime64_any_dtype

    # The first feature's dtype takes precedence; the second is only inspected
    # here when the first one is plain numeric.
    if _is_categorical_like(first_dtype):
        _plot_categorical_feature1(
            feature_1, feature_2, data, second_dtype, include_outliers, outlier_iqr_multiplier, ax, **kwargs
        )
    elif is_datetime(first_dtype):
        _plot_datetime_feature1(feature_1, feature_2, data, second_dtype, ax, **kwargs)
    elif _is_categorical_like(second_dtype):
        # Helper expects (categorical, numeric), so the feature names are swapped.
        _plot_categorical_vs_numeric(
            feature_2, feature_1, data, outlier_iqr_multiplier, include_outliers, ax, **kwargs
        )
    elif is_datetime(second_dtype):
        # Helper expects the datetime feature first so that it lands on the x-axis.
        _plot_xy(feature_2, feature_1, data, ax, **kwargs)
    else:
        _plot_numeric_features(feature_1, feature_2, data, ax, **kwargs)
    return ax
def _is_categorical_like(dtype):
"""Check if the dtype is categorical-like (categorical, boolean, or object)."""
return (
isinstance(dtype, pd.CategoricalDtype)
or pd.api.types.is_bool_dtype(dtype)
or pd.api.types.is_object_dtype(dtype)
)
def _plot_categorical_feature1(
    categorical_feature,
    feature_2,
    data,
    dtype2,
    include_outliers,
    outlier_iqr_multiplier,
    ax,
    **kwargs,
):
    """Dispatch the plot for a categorical-like first feature based on the second feature's dtype.

    :param categorical_feature: Name of the categorical-like column.
    :param feature_2: Name of the other column.
    :param data: DataFrame holding both columns.
    :param dtype2: dtype of ``feature_2``; decides which plotting helper is used.
    :param include_outliers: Forwarded to the categorical-vs-numeric helper.
    :param outlier_iqr_multiplier: Forwarded to the categorical-vs-numeric helper.
    :param ax: Axes to draw on.
    :param kwargs: Forwarded unchanged to the selected helper.
    """
    if _is_categorical_like(dtype2):
        _plot_categorical_vs_categorical(categorical_feature, feature_2, data, ax, **kwargs)
        return

    if pd.api.types.is_datetime64_any_dtype(dtype2):
        _plot_categorical_vs_datetime(categorical_feature, feature_2, data, ax, **kwargs)
        return

    # Neither categorical-like nor datetime: handle the second feature as numeric.
    _plot_categorical_vs_numeric(
        categorical_feature, feature_2, data, outlier_iqr_multiplier, include_outliers, ax, **kwargs
    )
def _plot_xy(datetime_feature, other_feature, data, ax, **kwargs):
    """Draw ``other_feature`` against ``datetime_feature`` with ``Axes.plot`` and label both axes.

    :param datetime_feature: Column name used for the x-axis.
    :param other_feature: Column name used for the y-axis.
    :param data: DataFrame holding both columns.
    :param ax: Axes to draw on.
    :param kwargs: Forwarded to ``Axes.plot``.
    """
    x_values, y_values = data[datetime_feature], data[other_feature]
    ax.plot(x_values, y_values, **kwargs)
    ax.set_xlabel(datetime_feature)
    ax.set_ylabel(other_feature)
def _plot_datetime_feature1(datetime_feature, feature_2, data, dtype2, ax, **kwargs):
    """Dispatch the plot for a datetime first feature based on the second feature's dtype.

    A categorical-like second feature is drawn as a violin plot over time; any other
    dtype is drawn as a line with the datetime feature on the x-axis.
    """
    if not _is_categorical_like(dtype2):
        _plot_xy(datetime_feature, feature_2, data, ax, **kwargs)
        return
    # The shared helper takes the categorical feature first and the datetime feature second.
    _plot_categorical_vs_datetime(feature_2, datetime_feature, data, ax, **kwargs)
def _plot_numeric_features(feature_1, feature_2, data, ax, **kwargs):
    """Scatter ``feature_2`` (y-axis) against ``feature_1`` (x-axis) and label both axes.

    :param feature_1: Column name used for the x-axis.
    :param feature_2: Column name used for the y-axis.
    :param data: DataFrame holding both columns.
    :param ax: Axes to draw on.
    :param kwargs: Forwarded to ``Axes.scatter``.
    """
    x_values, y_values = data[feature_1], data[feature_2]
    ax.scatter(x_values, y_values, **kwargs)
    ax.set_xlabel(feature_1)
    ax.set_ylabel(feature_2)
def _plot_categorical_vs_categorical(feature_1, feature_2, data, ax, **kwargs):
    """Plot when both features are categorical-like.

    Draws one histogram of ``feature_2`` values per distinct value of ``feature_1``.
    Each feature is first reduced to at most its 10 most frequent values plus an
    ``"Other values"`` bucket.

    :param feature_1: Column whose distinct values define the histogram groups (legend entries).
    :param feature_2: Column whose values are counted along the x-axis.
    :param data: DataFrame holding both columns.
    :param ax: Axes to draw on.
    :param kwargs: Forwarded to ``Axes.hist``.
    """
    dup_df = pd.DataFrame()
    dup_df[feature_1] = _copy_series_or_keep_top_10(data[feature_1])
    dup_df[feature_2] = _copy_series_or_keep_top_10(data[feature_2])
    group_feature_1 = dup_df[feature_1].unique().tolist()
    ax.hist(
        [dup_df.loc[dup_df[feature_1] == value, feature_2] for value in group_feature_1],
        label=group_feature_1,
        **kwargs,
    )
    # The histogram bins hold feature_2 values while each legend entry is a
    # feature_1 value, so the axis label and legend title must name them accordingly.
    ax.set_xlabel(feature_2)
    ax.legend(title=feature_1)
def _plot_categorical_vs_datetime(categorical_feature, datetime_feature, data, ax, **kwargs):
    """Plot when one feature is categorical-like and the other is datetime.

    Draws a violin plot with the datetime feature (converted to Matplotlib date
    numbers) on the x-axis and the categories on the y-axis. The categorical
    feature name is expected first and the datetime feature name second.

    :param categorical_feature: Name of the categorical-like column (y-axis).
    :param datetime_feature: Name of the datetime column (x-axis).
    :param data: DataFrame holding both columns.
    :param ax: Axes to draw on.
    :param kwargs: Forwarded to ``seaborn.violinplot``.
    """
    plot_frame = pd.DataFrame()
    # The violin is drawn over numeric x values; they are rendered as dates again below.
    plot_frame[datetime_feature] = data[datetime_feature].apply(dates.date2num)
    plot_frame[categorical_feature] = _copy_series_or_keep_top_10(data[categorical_feature])

    violin_ax = sns.violinplot(data=plot_frame, x=datetime_feature, y=categorical_feature, ax=ax, **kwargs)

    # Freeze the tick positions before restyling the labels
    # (presumably to avoid Matplotlib's fixed-locator warning - TODO confirm).
    tick_positions = violin_ax.get_xticks()
    violin_ax.xaxis.set_major_locator(ticker.FixedLocator(tick_positions))
    current_labels = violin_ax.get_xticklabels()
    violin_ax.set_xticklabels(current_labels, rotation=45, ha="right")

    ax.xaxis.set_major_formatter(_convert_numbers_to_dates)
def _plot_categorical_vs_numeric(
    categorical_feature,
    numeric_feature,
    data,
    outlier_iqr_multiplier,
    include_outliers,
    ax,
    **kwargs,
):
    """Draw a violin plot of a numeric feature for each value of a categorical-like feature.

    The categorical feature is reduced to at most its 10 most frequent values plus
    an ``"Other values"`` bucket. When ``include_outliers`` is False, rows whose
    numeric value lies outside the IQR fence [Q1 - k*IQR, Q3 + k*IQR] are dropped,
    where ``k`` is ``outlier_iqr_multiplier``; the fence is computed over the whole
    numeric column, not per category.

    :param categorical_feature: Name of the categorical-like column (x-axis and hue).
    :param numeric_feature: Name of the numeric column (y-axis).
    :param data: DataFrame holding both columns.
    :param outlier_iqr_multiplier: Multiplier ``k`` of the IQR fence.
    :param include_outliers: Whether to keep values outside the IQR fence.
    :param ax: Axes to draw on.
    :param kwargs: Forwarded to ``seaborn.violinplot``.
    :return: The Axes that was drawn on.
    """
    plot_frame = pd.DataFrame()
    plot_frame[categorical_feature] = _copy_series_or_keep_top_10(data[categorical_feature])
    plot_frame[numeric_feature] = data[numeric_feature]

    if not include_outliers:
        values = plot_frame[numeric_feature]
        first_quartile = values.quantile(0.25)
        third_quartile = values.quantile(0.75)
        smallest = values.min()
        largest = values.max()
        spread = third_quartile - first_quartile
        # The fence never extends past the observed range of the data.
        fence_low = max(smallest, first_quartile - outlier_iqr_multiplier * spread)
        fence_high = min(largest, third_quartile + outlier_iqr_multiplier * spread)
        inside_fence = (values >= fence_low) & (values <= fence_high)
        plot_frame = plot_frame[inside_fence].copy()
    else:
        plot_frame = plot_frame.copy()

    sns.violinplot(
        data=plot_frame,
        x=categorical_feature,
        y=numeric_feature,
        hue=categorical_feature,
        ax=ax,
        **kwargs,
    )
    # Axis labels are the column names with underscores turned into spaces, title-cased.
    x_title = categorical_feature.replace("_", " ").title()
    y_title = numeric_feature.replace("_", " ").title()
    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    return ax
def _copy_series_or_keep_top_10(series: pd.Series) -> pd.Series:
if series.dtype == bool:
return series.map({True: "True", False: "False"})
if len(series.unique()) > 10:
top10 = series.value_counts().nlargest(10).index
return series.map(lambda x: x if x in top10 else "Other values")
return series
@plt.FuncFormatter
def _convert_numbers_to_dates(x, pos):
    """Format a Matplotlib date number as ``YYYY-MM-DD HH:MM``.

    The decorator wraps this function in a ``FuncFormatter``, so the module-level
    name is a formatter object that can be passed to ``Axis.set_major_formatter``.

    :param x: Tick value as a Matplotlib date number.
    :param pos: Tick position supplied by the formatter protocol; unused.
    :return: The formatted timestamp string.
    """
    tick_datetime = dates.num2date(x)
    return tick_datetime.strftime("%Y-%m-%d %H:%M")