Coverage for credoai/modules/constants_metrics.py: 100% (17 statements)
coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1"""Constants for threshold metrics 

2 

3Define relationships between metric names (strings) and 

4metric functions, as well as alternate names for each metric name 

5""" 

6 

7from functools import partial 

8 

9from fairlearn import metrics as fl_metrics 

10from sklearn import metrics as sk_metrics 

11 

12from credoai.modules.metrics_credoai import ( 

13 equal_opportunity_difference, 

14 false_discovery_rate, 

15 false_omission_rate, 

16 gini_coefficient_discriminatory, 

17 ks_statistic, 

18) 

19 

20THRESHOLD_METRIC_CATEGORIES = ["BINARY_CLASSIFICATION_THRESHOLD"] 

21 

22MODEL_METRIC_CATEGORIES = [ 

23 "BINARY_CLASSIFICATION", 

24 "MULTICLASS_CLASSIFICATION", 

25 "REGRESSION", 

26 "CLUSTERING", 

27 "FAIRNESS", 

28] + THRESHOLD_METRIC_CATEGORIES 

29 

30NON_MODEL_METRIC_CATEGORIES = [ 

31 "PRIVACY", 

32 "SECURITY", 

33 "DATASET", 

34 "CUSTOM", 

35] 

36 

37METRIC_CATEGORIES = ( 

38 MODEL_METRIC_CATEGORIES + THRESHOLD_METRIC_CATEGORIES + NON_MODEL_METRIC_CATEGORIES 

39) 

40 

41SCALAR_METRIC_CATEGORIES = MODEL_METRIC_CATEGORIES + NON_MODEL_METRIC_CATEGORIES 

42 

43# MODEL METRICS 

44BINARY_CLASSIFICATION_FUNCTIONS = { 

45 "false_positive_rate": fl_metrics.false_positive_rate, 

46 "false_negative_rate": fl_metrics.false_negative_rate, 

47 "false_discovery_rate": false_discovery_rate, 

48 "false_omission_rate": false_omission_rate, 

49 "true_positive_rate": fl_metrics.true_positive_rate, 

50 "true_negative_rate": fl_metrics.true_negative_rate, 

51 "precision_score": sk_metrics.precision_score, 

52 "accuracy_score": sk_metrics.accuracy_score, 

53 "balanced_accuracy_score": sk_metrics.balanced_accuracy_score, 

54 "matthews_correlation_coefficient": sk_metrics.matthews_corrcoef, 

55 "f1_score": sk_metrics.f1_score, 

56 "average_precision_score": sk_metrics.average_precision_score, 

57 "roc_auc_score": sk_metrics.roc_auc_score, 

58 "selection_rate": fl_metrics.selection_rate, 

59 "overprediction": fl_metrics._mean_overprediction, 

60 "underprediction": fl_metrics._mean_underprediction, 

61 "gini_coefficient": gini_coefficient_discriminatory, 

62} 
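
# Illustrative usage sketch (not part of the original module): each value above
# is a callable with the sklearn-style (y_true, y_pred) signature, so a metric
# can be resolved by canonical name and applied directly. The label arrays
# below are hypothetical example data.
#
#     fn = BINARY_CLASSIFICATION_FUNCTIONS["false_positive_rate"]
#     fn([0, 1, 0, 1], [1, 1, 0, 0])  # -> 0.5 (one false positive among two negatives)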

# Define Fairness Metric Name Mapping
# Fairness metrics must have a similar signature to fairlearn.metrics.equalized_odds_difference
# (they should take sensitive_features and method)
FAIRNESS_FUNCTIONS = {
    "demographic_parity_difference": fl_metrics.demographic_parity_difference,
    "demographic_parity_ratio": fl_metrics.demographic_parity_ratio,
    "equalized_odds_difference": fl_metrics.equalized_odds_difference,
    "equal_opportunity_difference": equal_opportunity_difference,
}
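
# Illustrative usage sketch (not part of the original module) of the signature
# contract described above; the data is hypothetical. With these inputs the two
# groups' true- and false-positive rates are maximally far apart:
#
#     fl_metrics.equalized_odds_difference(
#         [0, 1, 0, 1],                             # y_true
#         [0, 1, 1, 0],                             # y_pred
#         sensitive_features=["a", "a", "b", "b"],
#         method="between_groups",
#     )  # -> 1.0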

# *** Define functions that require probabilities ***
PROBABILITY_FUNCTIONS = {"average_precision_score", "roc_auc_score", "gini_coefficient"}
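
# Illustrative dispatch sketch (an assumption about how callers use this set):
# metrics named here should receive scores/probabilities rather than hard
# labels. The y_prob and y_pred variables below are hypothetical.
#
#     name = "roc_auc_score"
#     y_input = y_prob if name in PROBABILITY_FUNCTIONS else y_pred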

# *** Define Alternative Naming ***
METRIC_EQUIVALENTS = {
    "false_positive_rate": ["fpr", "fallout_rate", "false_match_rate"],
    "false_negative_rate": ["fnr", "miss_rate", "false_non_match_rate"],
    "false_discovery_rate": ["fdr"],
    "true_positive_rate": ["tpr", "recall_score", "recall", "sensitivity", "hit_rate"],
    "true_negative_rate": ["tnr", "specificity"],
    "precision_score": ["precision"],
    "demographic_parity_difference": ["statistical_parity", "demographic_parity"],
    "demographic_parity_ratio": ["disparate_impact"],
    "average_odds_difference": ["average_odds"],
    "equal_opportunity_difference": ["equal_opportunity"],
    "equalized_odds_difference": ["equalized_odds"],
    "mean_absolute_error": ["MAE"],
    "mean_squared_error": ["MSE", "MSD", "mean_squared_deviation"],
    "root_mean_squared_error": ["RMSE"],
    "r2_score": ["r_squared", "r2"],
    "gini_coefficient": [
        "gini_index",
        "discriminatory_gini_index",
        "discriminatory_gini",
    ],
    "population_stability_index": ["psi", "PSI"],
}
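
# Illustrative sketch (not part of the original module): METRIC_EQUIVALENTS
# maps canonical name -> aliases, so resolving an alias means inverting the
# mapping. ALIAS_TO_CANONICAL is a hypothetical helper name.
#
#     ALIAS_TO_CANONICAL = {
#         alias: name
#         for name, aliases in METRIC_EQUIVALENTS.items()
#         for alias in aliases
#     }
#     ALIAS_TO_CANONICAL["fpr"]  # -> "false_positive_rate"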

# DATASET METRICS
DATASET_METRIC_TYPES = [
    "sensitive_feature_prediction_score",
    "demographic_parity_ratio",
    "demographic_parity_difference",
    "max_proxy_mutual_information",
]

# PRIVACY METRICS
PRIVACY_METRIC_TYPES = [
    "rule_based_attack_score",
    "model_based_attack_score",
    "membership_inference_attack_score",
]

# SECURITY METRICS
SECURITY_METRIC_TYPES = ["extraction_attack_score", "evasion_attack_score"]

# REGRESSION METRICS
REGRESSION_FUNCTIONS = {
    "explained_variance_score": sk_metrics.explained_variance_score,
    "max_error": sk_metrics.max_error,
    "mean_absolute_error": sk_metrics.mean_absolute_error,
    "mean_squared_error": sk_metrics.mean_squared_error,
    "root_mean_squared_error": partial(sk_metrics.mean_squared_error, squared=False),
    "mean_squared_log_error": sk_metrics.mean_squared_log_error,
    "mean_absolute_percentage_error": sk_metrics.mean_absolute_percentage_error,
    "median_absolute_error": sk_metrics.median_absolute_error,
    "r2_score": sk_metrics.r2_score,
    "mean_poisson_deviance": sk_metrics.mean_poisson_deviance,
    "mean_gamma_deviance": sk_metrics.mean_gamma_deviance,
    "d2_tweedie_score": sk_metrics.d2_tweedie_score,
    "mean_pinball_loss": sk_metrics.mean_pinball_loss,
    "target_ks_statistic": ks_statistic,
}
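
# Illustrative usage sketch (not part of the original module): the
# "root_mean_squared_error" entry is a partial over sklearn's
# mean_squared_error with squared=False, so it is called like any other
# (y_true, y_pred) metric. Example data is hypothetical.
#
#     rmse = REGRESSION_FUNCTIONS["root_mean_squared_error"]
#     rmse([0.0, 0.0], [3.0, 3.0])  # -> 3.0 (sqrt of mean squared error 9.0)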