Coverage for credoai/modules/constants_metrics.py: 100%
17 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
"""Constants for threshold metrics

Define relationships between metric names (strings) and
metric functions, as well as alternate names for each metric name
"""
7from functools import partial
9from fairlearn import metrics as fl_metrics
10from sklearn import metrics as sk_metrics
12from credoai.modules.metrics_credoai import (
13 equal_opportunity_difference,
14 false_discovery_rate,
15 false_omission_rate,
16 gini_coefficient_discriminatory,
17 ks_statistic,
18)
# METRIC CATEGORIES
# Category names are plain strings; the lists below group them by kind.

# Categories whose metrics are defined over decision thresholds.
THRESHOLD_METRIC_CATEGORIES = ["BINARY_CLASSIFICATION_THRESHOLD"]

# Categories of metrics computed for a model. The threshold categories are
# appended, so they are part of this list as well.
MODEL_METRIC_CATEGORIES = [
    "BINARY_CLASSIFICATION",
    "MULTICLASS_CLASSIFICATION",
    "REGRESSION",
    "CLUSTERING",
    "FAIRNESS",
] + THRESHOLD_METRIC_CATEGORIES

# Categories of metrics that are not model-performance metrics.
NON_MODEL_METRIC_CATEGORIES = [
    "PRIVACY",
    "SECURITY",
    "DATASET",
    "CUSTOM",
]

# Every known category, each listed exactly once.
# MODEL_METRIC_CATEGORIES already ends with THRESHOLD_METRIC_CATEGORIES, so a
# plain concatenation would list "BINARY_CLASSIFICATION_THRESHOLD" twice;
# dict.fromkeys drops the repeat while keeping first-seen order.
METRIC_CATEGORIES = list(
    dict.fromkeys(
        MODEL_METRIC_CATEGORIES
        + THRESHOLD_METRIC_CATEGORIES
        + NON_MODEL_METRIC_CATEGORIES
    )
)

# NOTE(review): because MODEL_METRIC_CATEGORIES contains the threshold
# category, it is included here too -- confirm that is intended.
SCALAR_METRIC_CATEGORIES = MODEL_METRIC_CATEGORIES + NON_MODEL_METRIC_CATEGORIES
# MODEL METRICS
# Maps each binary-classification metric name (string) to the callable that
# computes it. Callables come from fairlearn, scikit-learn and
# credoai.modules.metrics_credoai.
BINARY_CLASSIFICATION_FUNCTIONS = {
    "false_positive_rate": fl_metrics.false_positive_rate,
    "false_negative_rate": fl_metrics.false_negative_rate,
    "false_discovery_rate": false_discovery_rate,
    "false_omission_rate": false_omission_rate,
    "true_positive_rate": fl_metrics.true_positive_rate,
    "true_negative_rate": fl_metrics.true_negative_rate,
    "precision_score": sk_metrics.precision_score,
    "accuracy_score": sk_metrics.accuracy_score,
    "balanced_accuracy_score": sk_metrics.balanced_accuracy_score,
    "matthews_correlation_coefficient": sk_metrics.matthews_corrcoef,
    "f1_score": sk_metrics.f1_score,
    "average_precision_score": sk_metrics.average_precision_score,
    "roc_auc_score": sk_metrics.roc_auc_score,
    "selection_rate": fl_metrics.selection_rate,
    # NOTE(review): the two helpers below are underscore-prefixed (private by
    # convention) in fairlearn -- confirm they exist in the pinned version.
    "overprediction": fl_metrics._mean_overprediction,
    "underprediction": fl_metrics._mean_underprediction,
    "gini_coefficient": gini_coefficient_discriminatory,
}
# FAIRNESS METRICS
# Maps each fairness metric name (string) to the callable that computes it.
# Fairness metrics must have a signature similar to
# fairlearn.metrics.equalized_odds_difference: they should accept
# ``sensitive_features`` and ``method``.
FAIRNESS_FUNCTIONS = {
    "demographic_parity_difference": fl_metrics.demographic_parity_difference,
    "demographic_parity_ratio": fl_metrics.demographic_parity_ratio,
    "equalized_odds_difference": fl_metrics.equalized_odds_difference,
    "equal_opportunity_difference": equal_opportunity_difference,
}
# Names of metrics whose functions require probabilities.
PROBABILITY_FUNCTIONS = {
    "average_precision_score",
    "roc_auc_score",
    "gini_coefficient",
}
# ALTERNATIVE NAMING
# Maps each canonical metric name to the list of alternate names accepted
# for it.
# NOTE(review): "average_odds_difference" and "population_stability_index"
# have aliases here but no entry in the function tables visible in this
# module -- confirm they are registered elsewhere.
METRIC_EQUIVALENTS = {
    "false_positive_rate": ["fpr", "fallout_rate", "false_match_rate"],
    "false_negative_rate": ["fnr", "miss_rate", "false_non_match_rate"],
    "false_discovery_rate": ["fdr"],
    "true_positive_rate": ["tpr", "recall_score", "recall", "sensitivity", "hit_rate"],
    "true_negative_rate": ["tnr", "specificity"],
    "precision_score": ["precision"],
    "demographic_parity_difference": ["statistical_parity", "demographic_parity"],
    "demographic_parity_ratio": ["disparate_impact"],
    "average_odds_difference": ["average_odds"],
    "equal_opportunity_difference": ["equal_opportunity"],
    "equalized_odds_difference": ["equalized_odds"],
    "mean_absolute_error": ["MAE"],
    "mean_squared_error": ["MSE", "MSD", "mean_squared_deviation"],
    "root_mean_squared_error": ["RMSE"],
    "r2_score": ["r_squared", "r2"],
    "gini_coefficient": [
        "gini_index",
        "discriminatory_gini_index",
        "discriminatory_gini",
    ],
    "population_stability_index": ["psi", "PSI"],
}
# DATASET METRICS
# Names of the metric types listed for datasets.
DATASET_METRIC_TYPES = [
    "sensitive_feature_prediction_score",
    "demographic_parity_ratio",
    "demographic_parity_difference",
    "max_proxy_mutual_information",
]
# PRIVACY METRICS
# Names of the metric types listed for privacy.
PRIVACY_METRIC_TYPES = [
    "rule_based_attack_score",
    "model_based_attack_score",
    "membership_inference_attack_score",
]
# SECURITY METRICS
# Names of the metric types listed for security.
SECURITY_METRIC_TYPES = [
    "extraction_attack_score",
    "evasion_attack_score",
]
# REGRESSION METRICS
# scikit-learn 1.4 added ``root_mean_squared_error`` and deprecated the
# ``squared`` keyword of ``mean_squared_error`` (later removed). Prefer the
# dedicated function when the installed version has it and fall back to the
# ``squared=False`` partial on older releases.
_root_mean_squared_error = getattr(sk_metrics, "root_mean_squared_error", None)
if _root_mean_squared_error is None:
    _root_mean_squared_error = partial(sk_metrics.mean_squared_error, squared=False)

# Maps each regression metric name (string) to the callable that computes it.
REGRESSION_FUNCTIONS = {
    "explained_variance_score": sk_metrics.explained_variance_score,
    "max_error": sk_metrics.max_error,
    "mean_absolute_error": sk_metrics.mean_absolute_error,
    "mean_squared_error": sk_metrics.mean_squared_error,
    "root_mean_squared_error": _root_mean_squared_error,
    "mean_squared_log_error": sk_metrics.mean_squared_log_error,
    "mean_absolute_percentage_error": sk_metrics.mean_absolute_percentage_error,
    "median_absolute_error": sk_metrics.median_absolute_error,
    "r2_score": sk_metrics.r2_score,
    "mean_poisson_deviance": sk_metrics.mean_poisson_deviance,
    "mean_gamma_deviance": sk_metrics.mean_gamma_deviance,
    "d2_tweedie_score": sk_metrics.d2_tweedie_score,
    "mean_pinball_loss": sk_metrics.mean_pinball_loss,
    "target_ks_statistic": ks_statistic,
}