prepare_performance_data

prepare_performance_data(
    probs,
    reals,
    stratified_by=('probability_threshold',),
    by=0.01,
)

Prepare performance data for binary classification models.

This function computes a set of performance metrics for one or more binary classification models at each probability threshold on a grid. It builds on the binned counts produced by prepare_binned_classification_data: the counts are cumulatively summed across thresholds, and metrics such as sensitivity (TPR), specificity, precision (PPV), and net benefit are derived from the resulting totals.
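The binning-and-cumulative-sum idea can be sketched with NumPy alone. This is a hypothetical illustration (`sketch_performance_table` is not part of the library), and it assumes that a case counts as predicted positive when its probability is >= the threshold, that `by` divides 1 evenly, and that both classes are present. It returns plain arrays rather than a Polars DataFrame, and its net benefit uses the standard definition TP/n - FP/n * t/(1 - t).

```python
import numpy as np


def sketch_performance_table(probs, reals, by=0.01):
    """Hypothetical sketch: per-threshold metrics via binning + reverse cumsum.

    Assumptions (not confirmed by the library): prob >= threshold means
    "predicted positive", `by` divides 1 evenly, both classes are present.
    """
    probs = np.asarray(probs, dtype=float)
    reals = np.asarray(reals, dtype=int)
    n_bins = int(round(1 / by))
    # Bin edges 0, by, 2*by, ..., 1; each lower edge doubles as a threshold.
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Bin k holds probabilities in [edges[k], edges[k+1]); p == 1.0 joins the last bin.
    bin_idx = np.clip(np.searchsorted(edges, probs, side="right") - 1, 0, n_bins - 1)
    pos_per_bin = np.bincount(bin_idx[reals == 1], minlength=n_bins)
    neg_per_bin = np.bincount(bin_idx[reals == 0], minlength=n_bins)
    # Reverse cumulative sums: every case in bin k or above is predicted positive at edges[k].
    tp = np.cumsum(pos_per_bin[::-1])[::-1]
    fp = np.cumsum(neg_per_bin[::-1])[::-1]
    # Drop the 1.0 edge so the net-benefit odds t / (1 - t) stay finite.
    thresholds = edges[:-1]

    n = reals.size
    n_pos = int(reals.sum())
    n_neg = n - n_pos
    predicted_pos = tp + fp
    tpr = tp / n_pos
    fpr = fp / n_neg
    # PPV is undefined (NaN) where nothing is predicted positive.
    ppv = np.divide(tp, predicted_pos, out=np.full(n_bins, np.nan), where=predicted_pos > 0)
    net_benefit = tp / n - (fp / n) * (thresholds / (1 - thresholds))
    return {
        "threshold": thresholds,
        "tp": tp,
        "fp": fp,
        "tn": n_neg - fp,
        "fn": n_pos - tp,
        "tpr": tpr,
        "fpr": fpr,
        "specificity": 1 - fpr,
        "ppv": ppv,
        "net_benefit": net_benefit,
    }
```

Binning once and then cumulating means each threshold costs O(1) after a single pass over the data, instead of re-scanning every probability for every cutoff.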

The resulting DataFrame is the primary input for plotting functions such as plot_roc_curve and plot_precision_recall_curve.

Parameters

probs : Dict[str, np.ndarray] (required)
    A dictionary mapping model or dataset names (str) to their predicted probabilities (1-D NumPy arrays).
reals : Union[np.ndarray, Dict[str, np.ndarray]] (required)
    The true event labels, given either as a single NumPy array aligned with all pooled probabilities or as a dictionary mapping each dataset name to its corresponding array of true labels. Labels must be binary (0 or 1).
stratified_by : Sequence[str] (default: ('probability_threshold',))
    A sequence of strings naming the variables by which to stratify the data.
by : float (default: 0.01)
    The step size between consecutive probability thresholds. Smaller values evaluate performance at more cutoffs.
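The two accepted forms of `reals` can be illustrated with a small normalization step. The helper below, `normalize_reals`, is hypothetical and not part of the library; it assumes that a single array means the same labels apply to every entry in `probs`, and it checks only the documented requirement that labels be binary, plus a length match.

```python
from typing import Dict, Union

import numpy as np


def normalize_reals(
    probs: Dict[str, np.ndarray],
    reals: Union[np.ndarray, Dict[str, np.ndarray]],
) -> Dict[str, np.ndarray]:
    """Hypothetical helper: return one label array per key in `probs`."""
    if isinstance(reals, dict):
        # Dictionary form: look up each dataset's own labels.
        per_key = {name: np.asarray(reals[name]) for name in probs}
    else:
        # Single-array form (assumed): reuse the same labels for every entry.
        shared = np.asarray(reals)
        per_key = {name: shared for name in probs}
    for name, labels in per_key.items():
        if labels.shape != np.shape(probs[name]):
            raise ValueError(f"{name}: labels and probabilities differ in length")
        if not np.isin(labels, (0, 1)).all():
            raise ValueError(f"{name}: labels must be binary (0 or 1)")
    return per_key
```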

Returns

pl.DataFrame
    A Polars DataFrame in which each row corresponds to one probability cutoff for a given model or dataset. Columns include the cutoff value and performance metrics such as tpr, fpr, ppv, and net_benefit.

Examples

>>> import numpy as np
>>> probs_dict_test = {
...     "small_data_set": np.array(
...         [0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33]
...     )
... }
>>> reals_test = np.array([1, 1, 1, 1, 0, 0, 1, 0, 0, 1])
>>> performance_df = prepare_performance_data(
...     probs=probs_dict_test,
...     reals=reals_test,
...     by=0.1,
... )
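As a sanity check on the example, the metrics at the 0.5 cutoff can be reproduced by hand with NumPy, assuming the by=0.1 grid includes 0.5. No probability in the example equals 0.5, so the result does not depend on whether the library treats the cutoff as >= or >. The net-benefit line uses the standard definition, which is assumed (not confirmed here) to match the library's; no column names of performance_df are assumed.

```python
import numpy as np

probs = np.array([0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33])
reals = np.array([1, 1, 1, 1, 0, 0, 1, 0, 0, 1])
threshold = 0.5

# Predicted positive at or above the cutoff.
predicted = probs >= threshold
tp = int(np.sum(predicted & (reals == 1)))   # 5
fp = int(np.sum(predicted & (reals == 0)))   # 2
fn = int(np.sum(~predicted & (reals == 1)))  # 1
tn = int(np.sum(~predicted & (reals == 0)))  # 2
n = reals.size

tpr = tp / (tp + fn)  # sensitivity: 5/6
fpr = fp / (fp + tn)  # 2/4
ppv = tp / (tp + fp)  # precision: 5/7
# Standard net benefit: TP/n - FP/n * t/(1 - t) = 0.5 - 0.2 * 1
net_benefit = tp / n - (fp / n) * threshold / (1 - threshold)
```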