# Source code for ds_utils.metrics.curves

"""Module containing functions for plotting evaluation curves."""

from typing import Dict, Optional, Union
import warnings

import numpy as np
from plotly import graph_objects as go
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_auc_score, roc_curve


def plot_roc_curve_with_thresholds_annotations(
    y_true: np.ndarray,
    classifiers_names_and_scores_dict: Dict[str, np.ndarray],
    *,
    positive_label: Optional[Union[int, float, bool, str]] = None,
    sample_weight: Optional[np.ndarray] = None,
    drop_intermediate: bool = True,
    average: Optional[str] = "macro",
    max_fpr: Optional[float] = None,
    multi_class: str = "raise",
    labels: Optional[np.ndarray] = None,
    fig: Optional[go.Figure] = None,
    mode: Optional[str] = "lines+markers",
    add_random_classifier_line: bool = True,
    random_classifier_line_kw: Optional[Dict] = None,
    show_legend: bool = True,
    **kwargs,
) -> go.Figure:
    """Plot ROC curves with threshold annotations for multiple classifiers.

    Each classifier gets one trace. The hover text of every point shows its decision
    threshold, FPR and TPR, and the legend entry shows the classifier's ROC AUC.

    :param y_true: array-like of shape (n_samples,). True binary labels.
    :param classifiers_names_and_scores_dict: mapping from classifier name to classifier's score.
        Each score array must have the same shape as ``y_true``.
    :param positive_label: int, float, bool or str, default=None. The label of the positive class.
        It is used both for the ROC curve and for the AUC shown in the legend, so the two always
        describe the same class.
    :param sample_weight: array-like of shape (n_samples,), default=None. Sample weights.
    :param drop_intermediate: bool, default=True. Whether to drop some suboptimal thresholds which would
        not appear on a plotted ROC curve.
    :param average: {'micro', 'macro', 'samples', 'weighted'} or None, default='macro'.
        Forwarded to ``roc_auc_score``. If not None, this determines the type of averaging performed
        on the data.
    :param max_fpr: float > 0 and <= 1, default=None. If not None, the standardized partial AUC over the
        range [0, max_fpr] is returned.
    :param multi_class: {'raise', 'ovr', 'ovo'}, default='raise'. Forwarded to ``roc_auc_score``.
        Determines the type of configuration to use for multiclass targets.
    :param labels: array-like of shape (n_classes,), default=None. Forwarded to ``roc_auc_score``.
        Only used for multiclass targets. List of labels that index the classes in y_score.
    :param fig: plotly's Figure object, optional. The figure to plot on. A new one is created if None.
    :param mode: str, default='lines+markers'. Determines the drawing mode for this scatter trace.
    :param add_random_classifier_line: bool, default=True. Whether to plot a diagonal dashed black line
        which represents a random classifier.
    :param random_classifier_line_kw: dict, default=None. Keyword arguments to be passed to plotly's
        Scatter for rendering the random classifier line (e.g., line color, style). They override the
        defaults (dashed black line named "Random Classifier (AUC = 0.50)").
    :param show_legend: bool, default=True. Whether to display legend in the plot.
    :param kwargs: additional keyword arguments to be passed to plotly's Scatter for every
        classifier trace.
    :return: The Figure object with the plot drawn onto it.
    :raises ValueError: If the input data is invalid or inconsistent. For example, a score array whose
        shape differs from ``y_true``, or a failure while computing the ROC curve or the AUC.
    """
    if fig is None:
        fig = go.Figure()  # Create a new figure if none is provided

    # Accept any documented array-like (lists, pandas Series, ...); the shape check below needs an array.
    y_true = np.asarray(y_true)

    # roc_auc_score has no ``pos_label`` argument: for binary targets it always treats the larger label
    # as positive. Binarize the targets against ``positive_label`` so the AUC in the legend refers to the
    # same positive class as the curve computed by roc_curve(pos_label=positive_label).
    if positive_label is None:
        auc_y_true = y_true
    else:
        auc_y_true = np.asarray(y_true == positive_label).astype(int)

    for classifier_name, y_scores in classifiers_names_and_scores_dict.items():
        y_scores = np.asarray(y_scores)
        if y_true.shape != y_scores.shape:
            raise ValueError(
                f"Shape mismatch: y_true {y_true.shape} and y_scores {y_scores.shape} for classifier {classifier_name}"
            )

        try:
            fpr_array, tpr_array, thresholds = roc_curve(
                y_true,
                y_scores,
                pos_label=positive_label,
                sample_weight=sample_weight,
                drop_intermediate=drop_intermediate,
            )
        except ValueError as e:
            raise ValueError(f"Error calculating ROC curve for classifier {classifier_name}: {str(e)}") from e

        try:
            with warnings.catch_warnings():
                # Keep the plot quiet for ill-defined metric situations; real errors still raise.
                warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
                auc_score = roc_auc_score(
                    auc_y_true,
                    y_scores,
                    average=average,
                    sample_weight=sample_weight,
                    max_fpr=max_fpr,
                    multi_class=multi_class,
                    labels=labels,
                )
        except ValueError as e:
            raise ValueError(f"Error calculating AUC score for classifier {classifier_name}: {str(e)}") from e

        fig.add_trace(
            go.Scatter(
                x=fpr_array,
                y=tpr_array,
                mode=mode,
                text=[
                    f"Prob: {threshold:.2f}<br>FPR: {fpr:.2f}<br>TPR: {tpr:.2f}"
                    for fpr, tpr, threshold in zip(fpr_array, tpr_array, thresholds)
                ],
                hoverinfo="text",
                name=f"{classifier_name} (AUC = {auc_score:.2f})",
                **kwargs,
            )
        )

    if add_random_classifier_line:
        # Add dashed line for random classifier; user-supplied keyword arguments override the defaults.
        default_random_classifier_kw = {
            "line": dict(dash="dash", color="black"),
            "name": "Random Classifier (AUC = 0.50)",
        }
        if random_classifier_line_kw is not None:
            default_random_classifier_kw.update(random_classifier_line_kw)
        fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines", hoverinfo="name", **default_random_classifier_kw))

    fig.update_layout(xaxis_title="False Positive Rate", yaxis_title="True Positive Rate", showlegend=show_legend)
    return fig
def plot_precision_recall_curve_with_thresholds_annotations(
    y_true: np.ndarray,
    classifiers_names_and_scores_dict: Dict[str, np.ndarray],
    *,
    positive_label: Optional[Union[int, float, bool, str]] = None,
    sample_weight: Optional[np.ndarray] = None,
    drop_intermediate: bool = True,
    fig: Optional[go.Figure] = None,
    mode: Optional[str] = "lines+markers",
    plot_chance_level: bool = False,
    chance_level_kw: Optional[Dict] = None,
    show_legend: bool = True,
    **kwargs,
) -> go.Figure:
    """Plot Precision-Recall curves with threshold annotations for multiple classifiers.

    Each classifier gets one trace. The hover text of every point shows its decision
    threshold, precision and recall, and the legend entry shows the classifier's Average
    Precision (AP).

    :param y_true: array-like of shape (n_samples,). True binary labels.
    :param classifiers_names_and_scores_dict: mapping from classifier name to classifier's score.
        Each score array must have the same shape as ``y_true``.
    :param positive_label: int, float, bool or str, default=None. The label of the positive class.
        When None and ``plot_chance_level`` is True, the positive class is inferred as the larger of
        the two unique labels in ``y_true``, consistent with scikit-learn's convention. When None and
        ``plot_chance_level`` is False, scikit-learn's own default is applied.
    :param sample_weight: array-like of shape (n_samples,), default=None. Sample weights.
    :param drop_intermediate: bool, default=True. Whether to drop some suboptimal thresholds that don't
        change the precision. This is useful to create lighter Precision-Recall curves.
    :param fig: plotly's Figure object, optional. The figure to plot on. A new one is created if None.
    :param mode: str, default='lines+markers'. Determines the drawing mode for this scatter trace.
    :param plot_chance_level: bool, default=False. Whether to plot the chance level: a horizontal line
        at the (optionally sample-weighted) prevalence of the positive label in ``y_true``.
        Behavior is like sklearn:
        https://scikit-learn.org/stable/modules/generated/sklearn.metrics.PrecisionRecallDisplay.html
    :param chance_level_kw: dict, default=None. Keyword arguments to be passed to plotly's Scatter for
        rendering the chance level line (e.g., line color, style). They override the defaults.
    :param show_legend: bool, default=True. Whether to display legend in the plot.
    :param kwargs: additional keyword arguments to be passed to plotly's Scatter for every
        classifier trace.
    :return: The Figure object with the plot drawn onto it.
    :raises ValueError: If the input data is invalid or inconsistent. For example, a score array whose
        shape differs from ``y_true``, a non-binary ``y_true`` when the chance level has to infer the
        positive label, or a failure while computing the curve or the Average Precision.
    """
    if fig is None:
        fig = go.Figure()  # Create a new figure if none is provided

    # Normalize documented array-likes (lists, pandas Series, ...) so that the shape check and the
    # boolean-mask indexing of ``sample_weight`` in the prevalence computation work for every input type.
    y_true = np.asarray(y_true)
    if sample_weight is not None:
        sample_weight = np.asarray(sample_weight)

    # When plot_chance_level=True and positive_label=None, resolve to the larger unique label
    # (sklearn convention) and validate binary input, because the prevalence needs a concrete label.
    # Otherwise, keep positive_label as-is — sklearn will apply its own default when pos_label=None.
    effective_positive_label = positive_label
    if plot_chance_level and effective_positive_label is None:
        unique_labels = np.unique(y_true)
        if len(unique_labels) != 2:
            raise ValueError("y_true must be binary for plotting chance level")
        # np.unique returns sorted labels, so index 1 is the larger label.
        effective_positive_label = unique_labels[1]

    for classifier_name, y_scores in classifiers_names_and_scores_dict.items():
        y_scores = np.asarray(y_scores)
        if y_true.shape != y_scores.shape:
            raise ValueError(
                f"Shape mismatch: y_true {y_true.shape} and y_scores {y_scores.shape} for classifier {classifier_name}"
            )

        try:
            precision_array, recall_array, thresholds = precision_recall_curve(
                y_true,
                y_scores,
                pos_label=effective_positive_label,
                sample_weight=sample_weight,
                drop_intermediate=drop_intermediate,
            )
        except ValueError as e:
            raise ValueError(
                f"Error calculating Precision-Recall curve for classifier {classifier_name}: {str(e)}"
            ) from e

        # Forward only the arguments that are set, so average_precision_score keeps its own defaults
        # otherwise; pos_label matches the one used for the curve above.
        ap_kwargs = {}
        if effective_positive_label is not None:
            ap_kwargs["pos_label"] = effective_positive_label
        if sample_weight is not None:
            ap_kwargs["sample_weight"] = sample_weight
        try:
            ap = average_precision_score(y_true, y_scores, **ap_kwargs)
        except ValueError as e:
            raise ValueError(f"Error calculating Average Precision for classifier {classifier_name}: {str(e)}") from e

        # precision/recall contain one more point than thresholds (the final point has no threshold),
        # so pad with NaN and render that point's threshold as "N/A".
        display_thresholds = np.append(thresholds, np.nan)
        fig.add_trace(
            go.Scatter(
                x=recall_array,
                y=precision_array,
                mode=mode,
                text=[
                    f"Prob: {'N/A' if np.isnan(t) else f'{t:.2f}'}<br>Precision: {p:.2f}<br>Recall: {r:.2f}"
                    for p, r, t in zip(precision_array, recall_array, display_thresholds)
                ],
                hoverinfo="text",
                name=f"{classifier_name} (AP = {ap:0.2f})",
                **kwargs,
            )
        )

    if plot_chance_level:
        # Note: AP here refers to the chance-level Average Precision, which equals prevalence.
        # This naming convention is consistent with scikit-learn's PrecisionRecallDisplay.
        if sample_weight is not None:
            prevalence = np.sum(sample_weight[y_true == effective_positive_label]) / np.sum(sample_weight)
        else:
            prevalence = np.sum(y_true == effective_positive_label) / len(y_true)

        # Default styling for chance level line; user-supplied keyword arguments override it.
        default_chance_level_kw = {
            "line": dict(dash="dash", color="black"),
            "name": f"Chance level (AP = {prevalence:0.2f})",
        }
        if chance_level_kw is not None:
            default_chance_level_kw.update(chance_level_kw)

        # Plot horizontal line at prevalence
        fig.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[prevalence, prevalence],
                mode="lines",
                hoverinfo="name",
                **default_chance_level_kw,
            )
        )

    fig.update_layout(xaxis_title="Recall", yaxis_title="Precision", showlegend=show_legend)
    return fig