Coverage for credoai/modules/constants_metrics.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

"""Constants for metrics.

Define relationships between metric names (strings) and
metric functions, as well as alternate names for each metric name.
"""

6 

7from functools import partial 

8 

9from fairlearn import metrics as fl_metrics 

10from sklearn import metrics as sk_metrics 

11 

12from credoai.artifacts.model.constants_model import MODEL_TYPES 

13from credoai.modules.metrics_credoai import ( 

14 equal_opportunity_difference, 

15 false_discovery_rate, 

16 false_omission_rate, 

17 gini_coefficient_discriminatory, 

18 ks_statistic, 

19 multiclass_confusion_metrics, 

20) 

21 

# *** Metric categories ***
# Plain lists of category-name strings; the later lists are built by
# concatenating the earlier ones, so definition order matters.
THRESHOLD_METRIC_CATEGORIES = ["BINARY_CLASSIFICATION_THRESHOLD"]

# Model-related categories; the threshold categories are appended at the end.
MODEL_METRIC_CATEGORIES = [
    "CLUSTERING",
    "FAIRNESS",
] + THRESHOLD_METRIC_CATEGORIES

NON_MODEL_METRIC_CATEGORIES = [
    "PRIVACY",
    "SECURITY",
    "DATASET",
    "CUSTOM",
]

# Every recognized category name. MODEL_METRIC_CATEGORIES already ends with
# THRESHOLD_METRIC_CATEGORIES, so that list is not appended a second time;
# doing so listed "BINARY_CLASSIFICATION_THRESHOLD" twice.
METRIC_CATEGORIES = (
    MODEL_TYPES
    + MODEL_METRIC_CATEGORIES
    + NON_MODEL_METRIC_CATEGORIES
)

# NOTE(review): this includes the threshold categories via
# MODEL_METRIC_CATEGORIES -- confirm that is intended for "scalar" metrics.
SCALAR_METRIC_CATEGORIES = MODEL_METRIC_CATEGORIES + NON_MODEL_METRIC_CATEGORIES

44 

# MODEL METRICS
# Define Binary classification name mapping.
# Binary classification metrics must have a similar signature to sklearn metrics
# Keys are metric-name strings; values are the callables that compute them,
# drawn from sklearn.metrics, fairlearn.metrics and credoai's own metrics module.
BINARY_CLASSIFICATION_FUNCTIONS = {
    "accuracy_score": sk_metrics.accuracy_score,
    "average_precision_score": sk_metrics.average_precision_score,
    "balanced_accuracy_score": sk_metrics.balanced_accuracy_score,
    "f1_score": sk_metrics.f1_score,
    "false_discovery_rate": false_discovery_rate,
    "false_negative_rate": fl_metrics.false_negative_rate,
    "false_omission_rate": false_omission_rate,
    "false_positive_rate": fl_metrics.false_positive_rate,
    "gini_coefficient": gini_coefficient_discriminatory,
    "matthews_correlation_coefficient": sk_metrics.matthews_corrcoef,
    # NOTE(review): underscore-prefixed fairlearn name -- presumably not part
    # of its public API; confirm it is still exported by the pinned version.
    "overprediction": fl_metrics._mean_overprediction,
    "precision_score": sk_metrics.precision_score,
    "roc_auc_score": sk_metrics.roc_auc_score,
    "selection_rate": fl_metrics.selection_rate,
    "true_negative_rate": fl_metrics.true_negative_rate,
    "true_positive_rate": fl_metrics.true_positive_rate,
    # NOTE(review): same underscore-prefixed caveat as "overprediction".
    "underprediction": fl_metrics._mean_underprediction,
}

67 

# Define Multiclass classification name mapping.
# Multiclass classification metrics must have a similar signature to sklearn metrics
# Same metric-name keys as the binary mapping where a multiclass variant is
# provided. Several entries are functools.partial objects that pre-bind a
# keyword argument (metric=..., average=..., multi_class=...) to the callable.
MULTICLASS_CLASSIFICATION_FUNCTIONS = {
    "accuracy_score": partial(multiclass_confusion_metrics, metric="ACC"),
    "balanced_accuracy_score": sk_metrics.balanced_accuracy_score,
    "f1_score": partial(sk_metrics.f1_score, average="weighted"),
    "false_discovery_rate": partial(multiclass_confusion_metrics, metric="FDR"),
    "false_negative_rate": partial(multiclass_confusion_metrics, metric="FNR"),
    "false_positive_rate": partial(multiclass_confusion_metrics, metric="FPR"),
    "gini_coefficient": partial(
        gini_coefficient_discriminatory, multi_class="ovo", average="weighted"
    ),
    "matthews_correlation_coefficient": sk_metrics.matthews_corrcoef,
    "overprediction": fl_metrics._mean_overprediction,
    "precision_score": partial(sk_metrics.precision_score, average="weighted"),
    "roc_auc_score": partial(
        sk_metrics.roc_auc_score, multi_class="ovo", average="weighted"
    ),
    "selection_rate": fl_metrics.selection_rate,
    "true_negative_rate": partial(multiclass_confusion_metrics, metric="TNR"),
    "true_positive_rate": partial(multiclass_confusion_metrics, metric="TPR"),
    "underprediction": fl_metrics._mean_underprediction,
}

91 

# Define Fairness Metric Name Mapping
# Fairness metrics must have a similar signature to fairlearn.metrics.equalized_odds_difference
# (they should take sensitive_features and method)
# Three entries come straight from fairlearn.metrics; the last one is
# imported from credoai.modules.metrics_credoai.
FAIRNESS_FUNCTIONS = {
    "demographic_parity_difference": fl_metrics.demographic_parity_difference,
    "demographic_parity_ratio": fl_metrics.demographic_parity_ratio,
    "equalized_odds_difference": fl_metrics.equalized_odds_difference,
    "equal_opportunity_difference": equal_opportunity_difference,
}

101 

102 

# *** Define functions that require probabilities ***
# A set of metric-name strings (not callables); each name is also a key of
# BINARY_CLASSIFICATION_FUNCTIONS.
PROBABILITY_FUNCTIONS = {"average_precision_score", "roc_auc_score", "gini_coefficient"}

105 

# *** Define Alternative Naming ***
# Maps a metric name to the list of alternate names given for it.
# Alias strings are case-sensitive as written (e.g. both "psi" and "PSI" are
# listed explicitly).
METRIC_EQUIVALENTS = {
    "average_odds_difference": ["average_odds"],
    "average_precision_score": ["average_precision"],
    "demographic_parity_difference": ["statistical_parity", "demographic_parity"],
    "demographic_parity_ratio": ["disparate_impact"],
    "equal_opportunity_difference": ["equal_opportunity"],
    "equalized_odds_difference": ["equalized_odds"],
    "false_positive_rate": ["fpr", "fallout_rate", "false_match_rate"],
    "false_negative_rate": ["fnr", "miss_rate", "false_non_match_rate"],
    "false_discovery_rate": ["fdr"],
    "gini_coefficient": [
        "gini_index",
        "discriminatory_gini_index",
        "discriminatory_gini",
    ],
    "mean_absolute_error": ["MAE"],
    "mean_squared_error": ["MSE", "MSD", "mean_squared_deviation"],
    "population_stability_index": ["psi", "PSI"],
    "precision_score": ["precision"],
    "root_mean_squared_error": ["RMSE"],
    "r2_score": ["r_squared", "r2"],
    "true_positive_rate": ["tpr", "recall_score", "recall", "sensitivity", "hit_rate"],
    "true_negative_rate": ["tnr", "specificity"],
}

131 

# DATASET METRICS
# List of metric-name strings for dataset metrics.
DATASET_METRIC_TYPES = [
    "sensitive_feature_prediction_score",
    "demographic_parity_ratio",
    "demographic_parity_difference",
    "max_proxy_mutual_information",
]

139 

# PRIVACY METRICS
# List of metric-name strings for privacy metrics.
PRIVACY_METRIC_TYPES = [
    "rule_based_attack_score",
    "model_based_attack_score",
    "membership_inference_attack_score",
]

146 

# SECURITY METRICS
# List of metric-name strings for security metrics.
SECURITY_METRIC_TYPES = ["extraction_attack_score", "evasion_attack_score"]

149 

# REGRESSION METRICS
# Maps regression metric-name strings to the callables that compute them.
# All but the last entry come from sklearn.metrics; "target_ks_statistic" is
# imported from credoai.modules.metrics_credoai.
REGRESSION_FUNCTIONS = {
    "explained_variance_score": sk_metrics.explained_variance_score,
    "max_error": sk_metrics.max_error,
    "mean_absolute_error": sk_metrics.mean_absolute_error,
    "mean_squared_error": sk_metrics.mean_squared_error,
    # scikit-learn 1.4 added a dedicated root_mean_squared_error function and
    # deprecated the `squared` keyword of mean_squared_error (later removed).
    # Prefer the dedicated function when the installed version provides it;
    # otherwise fall back to the original partial. Building the partial is
    # harmless on any version because `squared` is only inspected when called.
    "root_mean_squared_error": getattr(
        sk_metrics,
        "root_mean_squared_error",
        partial(sk_metrics.mean_squared_error, squared=False),
    ),
    "mean_squared_log_error": sk_metrics.mean_squared_log_error,
    "mean_absolute_percentage_error": sk_metrics.mean_absolute_percentage_error,
    "median_absolute_error": sk_metrics.median_absolute_error,
    "r2_score": sk_metrics.r2_score,
    "mean_poisson_deviance": sk_metrics.mean_poisson_deviance,
    "mean_gamma_deviance": sk_metrics.mean_gamma_deviance,
    "d2_tweedie_score": sk_metrics.d2_tweedie_score,
    "mean_pinball_loss": sk_metrics.mean_pinball_loss,
    "target_ks_statistic": ks_statistic,
}