Coverage for credoai/modules/constants_metrics.py: 100%
19 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1"""Constants for threshold metrics
3Define relationships between metric names (strings) and
4metric functions, as well as alternate names for each metric name
5"""
7from functools import partial
9from fairlearn import metrics as fl_metrics
10from sklearn import metrics as sk_metrics
12from credoai.artifacts.model.constants_model import MODEL_TYPES
13from credoai.modules.metrics_credoai import (
14 equal_opportunity_difference,
15 false_discovery_rate,
16 false_omission_rate,
17 gini_coefficient_discriminatory,
18 ks_statistic,
19 multiclass_confusion_metrics,
20)
# Metric categories for threshold metrics.
THRESHOLD_METRIC_CATEGORIES = ["BINARY_CLASSIFICATION_THRESHOLD"]

# Metric categories for model metrics: two fixed entries followed by every
# threshold category.
MODEL_METRIC_CATEGORIES = ["CLUSTERING", "FAIRNESS", *THRESHOLD_METRIC_CATEGORIES]
# Metric categories that are not part of the model metric categories.
NON_MODEL_METRIC_CATEGORIES = ["PRIVACY", "SECURITY", "DATASET", "CUSTOM"]
# Every known metric category, in first-seen order.
# MODEL_METRIC_CATEGORIES already ends with THRESHOLD_METRIC_CATEGORIES, so a
# plain concatenation lists the threshold categories twice. dict.fromkeys
# drops the repeats while preserving order, and the result is still a list.
# NOTE(review): assumes every entry (including those from MODEL_TYPES) is a
# hashable string -- confirm against constants_model.
METRIC_CATEGORIES = list(
    dict.fromkeys(
        MODEL_TYPES
        + MODEL_METRIC_CATEGORIES
        + THRESHOLD_METRIC_CATEGORIES
        + NON_MODEL_METRIC_CATEGORIES
    )
)
# Model metric categories followed by the non-model metric categories.
SCALAR_METRIC_CATEGORIES = [*MODEL_METRIC_CATEGORIES, *NON_MODEL_METRIC_CATEGORIES]
# MODEL METRICS
# Binary classification: metric name -> metric function.
# Every function registered here must have a signature similar to the
# sklearn metrics.
# NOTE(review): "overprediction" and "underprediction" point at
# underscore-prefixed (non-public) fairlearn names -- confirm they exist in
# the fairlearn version this project pins.
BINARY_CLASSIFICATION_FUNCTIONS = dict(
    accuracy_score=sk_metrics.accuracy_score,
    average_precision_score=sk_metrics.average_precision_score,
    balanced_accuracy_score=sk_metrics.balanced_accuracy_score,
    f1_score=sk_metrics.f1_score,
    false_discovery_rate=false_discovery_rate,
    false_negative_rate=fl_metrics.false_negative_rate,
    false_omission_rate=false_omission_rate,
    false_positive_rate=fl_metrics.false_positive_rate,
    gini_coefficient=gini_coefficient_discriminatory,
    matthews_correlation_coefficient=sk_metrics.matthews_corrcoef,
    overprediction=fl_metrics._mean_overprediction,
    precision_score=sk_metrics.precision_score,
    roc_auc_score=sk_metrics.roc_auc_score,
    selection_rate=fl_metrics.selection_rate,
    true_negative_rate=fl_metrics.true_negative_rate,
    true_positive_rate=fl_metrics.true_positive_rate,
    underprediction=fl_metrics._mean_underprediction,
)
# Multiclass classification: metric name -> metric function.
# Every function registered here must have a signature similar to the
# sklearn metrics. Where the underlying function needs extra arguments,
# functools.partial presets them as keyword arguments.
MULTICLASS_CLASSIFICATION_FUNCTIONS = dict(
    accuracy_score=partial(multiclass_confusion_metrics, metric="ACC"),
    balanced_accuracy_score=sk_metrics.balanced_accuracy_score,
    f1_score=partial(sk_metrics.f1_score, average="weighted"),
    false_discovery_rate=partial(multiclass_confusion_metrics, metric="FDR"),
    false_negative_rate=partial(multiclass_confusion_metrics, metric="FNR"),
    false_positive_rate=partial(multiclass_confusion_metrics, metric="FPR"),
    gini_coefficient=partial(
        gini_coefficient_discriminatory, multi_class="ovo", average="weighted"
    ),
    matthews_correlation_coefficient=sk_metrics.matthews_corrcoef,
    overprediction=fl_metrics._mean_overprediction,
    precision_score=partial(sk_metrics.precision_score, average="weighted"),
    roc_auc_score=partial(
        sk_metrics.roc_auc_score, multi_class="ovo", average="weighted"
    ),
    selection_rate=fl_metrics.selection_rate,
    true_negative_rate=partial(multiclass_confusion_metrics, metric="TNR"),
    true_positive_rate=partial(multiclass_confusion_metrics, metric="TPR"),
    underprediction=fl_metrics._mean_underprediction,
)
# Fairness metrics: metric name -> metric function.
# Every function registered here must have a signature similar to
# fairlearn.metrics.equalized_odds_difference
# (it should take sensitive_features and method).
FAIRNESS_FUNCTIONS = dict(
    demographic_parity_difference=fl_metrics.demographic_parity_difference,
    demographic_parity_ratio=fl_metrics.demographic_parity_ratio,
    equalized_odds_difference=fl_metrics.equalized_odds_difference,
    equal_opportunity_difference=equal_opportunity_difference,
)
# *** Define functions that require probabilities ***
PROBABILITY_FUNCTIONS = {
    "average_precision_score",
    "gini_coefficient",
    "roc_auc_score",
}
# *** Define Alternative Naming ***
# Maps a metric name to the list of alternate names accepted for it.
METRIC_EQUIVALENTS = dict(
    average_odds_difference=["average_odds"],
    average_precision_score=["average_precision"],
    demographic_parity_difference=["statistical_parity", "demographic_parity"],
    demographic_parity_ratio=["disparate_impact"],
    equal_opportunity_difference=["equal_opportunity"],
    equalized_odds_difference=["equalized_odds"],
    false_positive_rate=["fpr", "fallout_rate", "false_match_rate"],
    false_negative_rate=["fnr", "miss_rate", "false_non_match_rate"],
    false_discovery_rate=["fdr"],
    gini_coefficient=[
        "gini_index",
        "discriminatory_gini_index",
        "discriminatory_gini",
    ],
    mean_absolute_error=["MAE"],
    mean_squared_error=["MSE", "MSD", "mean_squared_deviation"],
    population_stability_index=["psi", "PSI"],
    precision_score=["precision"],
    root_mean_squared_error=["RMSE"],
    r2_score=["r_squared", "r2"],
    true_positive_rate=["tpr", "recall_score", "recall", "sensitivity", "hit_rate"],
    true_negative_rate=["tnr", "specificity"],
)
# DATASET METRICS
# Names of the metric types listed for datasets.
DATASET_METRIC_TYPES = [
    "sensitive_feature_prediction_score",
    "demographic_parity_ratio",
    "demographic_parity_difference",
    "max_proxy_mutual_information",
]
# PRIVACY METRICS
# Names of the metric types listed for privacy.
PRIVACY_METRIC_TYPES = [
    "rule_based_attack_score",
    "model_based_attack_score",
    "membership_inference_attack_score",
]
# SECURITY METRICS
# Names of the metric types listed for security.
SECURITY_METRIC_TYPES = [
    "extraction_attack_score",
    "evasion_attack_score",
]
# REGRESSION METRICS
# scikit-learn 1.4 added `root_mean_squared_error` and deprecated the
# `squared` argument of `mean_squared_error` (removed in 1.6). Use the
# dedicated function when the installed scikit-learn provides it, and fall
# back to the older `squared=False` spelling otherwise.
_root_mean_squared_error = getattr(sk_metrics, "root_mean_squared_error", None)
if _root_mean_squared_error is None:
    _root_mean_squared_error = partial(sk_metrics.mean_squared_error, squared=False)

# Regression: metric name -> metric function.
REGRESSION_FUNCTIONS = {
    "explained_variance_score": sk_metrics.explained_variance_score,
    "max_error": sk_metrics.max_error,
    "mean_absolute_error": sk_metrics.mean_absolute_error,
    "mean_squared_error": sk_metrics.mean_squared_error,
    "root_mean_squared_error": _root_mean_squared_error,
    "mean_squared_log_error": sk_metrics.mean_squared_log_error,
    "mean_absolute_percentage_error": sk_metrics.mean_absolute_percentage_error,
    "median_absolute_error": sk_metrics.median_absolute_error,
    "r2_score": sk_metrics.r2_score,
    "mean_poisson_deviance": sk_metrics.mean_poisson_deviance,
    "mean_gamma_deviance": sk_metrics.mean_gamma_deviance,
    "d2_tweedie_score": sk_metrics.d2_tweedie_score,
    "mean_pinball_loss": sk_metrics.mean_pinball_loss,
    "target_ks_statistic": ks_statistic,
}