Coverage for credoai/modules/metrics.py: 88%


import re
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

from credoai.modules.constants_metrics import *
from credoai.modules.constants_threshold_metrics import *
from credoai.utils.common import ValidationError, humanize_label

@dataclass
class Metric:
    """Class to define metrics

    Metric categories determine what kind of use the metric is designed for. Credo AI assumes
    that metric signatures correspond with either scikit-learn or fairlearn method signatures,
    in the case of binary/multiclass classification, regression, clustering and fairness metrics.

    Dataset metrics are used as documentation placeholders and define no function.
    See DATASET_METRICS for examples. CUSTOM metrics have no expectations and will
    not be automatically used by Lens modules.

    Metric Categories:

    * {BINARY|MULTICLASS}_CLASSIFICATION: metrics like `scikit-learn's classification metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * REGRESSION: metrics like `scikit-learn's regression metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * CLUSTERING: metrics like `scikit-learn's clustering metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * FAIRNESS: metrics like `fairlearn's equalized odds metric <https://fairlearn.org/v0.5.0/api_reference/fairlearn.metrics.html>`_
    * DATASET: metrics intended as documentation placeholders for dataset assessments; they define no function (see DATASET_METRICS)
    * CUSTOM: no expectations for `fun`; these metrics are not automatically used by Lens modules

    Parameters
    ----------
    name : str
        The primary name of the metric
    metric_category : str
        One of the METRIC_CATEGORIES listed above
    fun : callable, optional
        The function definition of the metric. If None, the metric cannot be used and is only
        defined for documentation purposes
    takes_prob : bool, optional
        Whether the function takes the decision probabilities
        instead of the predicted class, as for ROC AUC. Similar to `needs_proba` used by
        `sklearn <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.make_scorer.html>`_,
        by default False
    equivalent_names : list, optional
        List of other names for the metric
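
    Example
    -------
    A minimal, illustrative sketch (assumes scikit-learn is available and that
    "BINARY_CLASSIFICATION" is one of the METRIC_CATEGORIES defined in
    constants_metrics)::

        from sklearn.metrics import accuracy_score

        metric = Metric("accuracy_score", "BINARY_CLASSIFICATION", accuracy_score)
        metric.is_metric("Accuracy Score")          # True -- names are standardized
        metric(y_true=[0, 1, 1], y_pred=[0, 1, 0])  # ~0.67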

47 """ 

48 

49 name: str 

50 metric_category: str 

51 fun: Optional[Callable[[Any], Any]] = None 

52 takes_prob: Optional[bool] = False 

53 equivalent_names: Optional[List[str]] = None 

54 

55 def __post_init__(self): 

56 if self.equivalent_names is None: 

57 self.equivalent_names = {self.name} 

58 else: 

59 self.equivalent_names = set(self.equivalent_names + [self.name]) 

60 self.metric_category = self.metric_category.upper() 

61 if self.metric_category not in METRIC_CATEGORIES: 

62 raise ValidationError(f"metric type ({self.metric_category}) isn't valid") 

63 self.humanized_type = humanize_label(self.name) 

    def __call__(self, **kwargs):
        return self.fun(**kwargs)

    def get_fun_doc(self):
        if self.fun:
            return self.fun.__doc__

    def print_fun_doc(self):
        print(self.get_fun_doc())

    def is_metric(
        self, metric_name: str, metric_categories: Optional[List[str]] = None
    ):
        metric_name = self.standardize_metric_name(metric_name)
        if self.equivalent_names:
            name_match = metric_name in self.equivalent_names
        if metric_categories is not None:
            return name_match and self.metric_category in metric_categories
        return name_match

    def standardize_metric_name(self, metric):
        # standardize: lower-case, collapse repeated whitespace, and replace
        # delimiters (hyphens, spaces, underscores) with underscores
        standard = "_".join(re.split(r"[-\s_]", re.sub(r"\s\s+", " ", metric.lower())))
        return standard


def metrics_from_dict(dict, metric_category, probability_functions, metric_equivalents):
    # Convert to metric objects
    metrics = {}
    for metric_name, fun in dict.items():
        equivalents = metric_equivalents.get(metric_name, [])  # get equivalent names
        # whether the metric takes probabilities instead of predictions
        takes_prob = metric_name in probability_functions
        metric = Metric(metric_name, metric_category, fun, takes_prob, equivalents)
        metrics[metric_name] = metric
    return metrics


def find_metrics(metric_name, metric_category=None):
    """Find metrics by name and (optionally) metric category

    Parameters
    ----------
    metric_name : str
        metric name to search for
    metric_category : str or list, optional
        category or list of categories to constrain the search to, by default None

    Returns
    -------
    list
        list of matching Metric objects
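
    Examples
    --------
    Illustrative only -- the exact metric names available depend on the
    dictionaries defined in constants_metrics::

        find_metrics("accuracy_score", "BINARY_CLASSIFICATION")
        find_metrics("roc auc")  # names are standardized before matching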

118 """ 

119 if isinstance(metric_category, str): 

120 metric_category = [metric_category] 

121 matched_metrics = [ 

122 i for i in ALL_METRICS if i.is_metric(metric_name, metric_category) 

123 ] 

124 return matched_metrics 


# Convert the metric function dictionaries into dictionaries of Metric objects
BINARY_CLASSIFICATION_METRICS = metrics_from_dict(
    BINARY_CLASSIFICATION_FUNCTIONS,
    "BINARY_CLASSIFICATION",
    PROBABILITY_FUNCTIONS,
    METRIC_EQUIVALENTS,
)

THRESHOLD_VARYING_METRICS = metrics_from_dict(
    THRESHOLD_PROBABILITY_FUNCTIONS,
    "BINARY_CLASSIFICATION_THRESHOLD",
    THRESHOLD_PROBABILITY_FUNCTIONS,
    THRESHOLD_METRIC_EQUIVALENTS,
)

REGRESSION_METRICS = metrics_from_dict(
    REGRESSION_FUNCTIONS, "REGRESSION", PROBABILITY_FUNCTIONS, METRIC_EQUIVALENTS
)

FAIRNESS_METRICS = metrics_from_dict(
    FAIRNESS_FUNCTIONS, "FAIRNESS", PROBABILITY_FUNCTIONS, METRIC_EQUIVALENTS
)

DATASET_METRICS = {m: Metric(m, "DATASET", None, False) for m in DATASET_METRIC_TYPES}

PRIVACY_METRICS = {m: Metric(m, "PRIVACY", None, False) for m in PRIVACY_METRIC_TYPES}

SECURITY_METRICS = {
    m: Metric(m, "SECURITY", None, False) for m in SECURITY_METRIC_TYPES
}


METRIC_NAMES = (
    list(BINARY_CLASSIFICATION_METRICS.keys())
    + list(THRESHOLD_VARYING_METRICS.keys())
    + list(FAIRNESS_METRICS.keys())
    + list(DATASET_METRICS.keys())
    + list(PRIVACY_METRICS.keys())
    + list(SECURITY_METRICS.keys())
    + list(REGRESSION_METRICS.keys())
)

ALL_METRICS = (
    list(BINARY_CLASSIFICATION_METRICS.values())
    + list(THRESHOLD_VARYING_METRICS.values())
    + list(FAIRNESS_METRICS.values())
    + list(DATASET_METRICS.values())
    + list(PRIVACY_METRICS.values())
    + list(SECURITY_METRICS.values())
    + list(REGRESSION_METRICS.values())
)