Coverage for credoai/modules/metrics.py: 85%

import re
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

from credoai.artifacts.model.constants_model import MODEL_TYPES
from credoai.modules.constants_metrics import *
from credoai.modules.constants_threshold_metrics import *
from credoai.utils.common import ValidationError, humanize_label, wrap_list


@dataclass
class Metric:
    """Class to define metrics

    Metric categories determine what kind of use the metric is designed for. Credo AI assumes
    that metric signatures correspond to scikit-learn or fairlearn method signatures for
    binary/multiclass classification, regression, clustering and fairness metrics.

    Dataset metrics are used as documentation placeholders and define no function.
    See DATASET_METRICS for examples. CUSTOM metrics have no expectations and will
    not be automatically used by Lens modules.

    Metric Categories:

    * {BINARY|MULTICLASS}_CLASSIFICATION: metrics like `scikit-learn's classification metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * REGRESSION: metrics like `scikit-learn's regression metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * CLUSTERING: metrics like `scikit-learn's clustering metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * FAIRNESS: metrics like `fairlearn's equalized odds metric <https://fairlearn.org/v0.5.0/api_reference/fairlearn.metrics.html>`_
    * DATASET: metrics used as documentation placeholders for dataset assessments; they define no function (see DATASET_METRICS)
    * CUSTOM: metrics with no expectations for `fun`; not automatically used by Lens modules

    Parameters
    ----------
    name : str
        The primary name of the metric
    metric_category : str
        Defined to be one of the METRIC_CATEGORIES, above
    fun : callable, optional
        The function definition of the metric. If None, the metric cannot be used and is only
        defined for documentation purposes
    takes_prob : bool, optional
        Whether the function takes the decision probabilities
        instead of the predicted class, as for ROC AUC. Similar to `needs_proba` used by
        `sklearn <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.make_scorer.html>`_,
        by default False
    equivalent_names : list
        List of other names for the metric
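
    Examples
    --------
    A minimal usage sketch, assuming scikit-learn's ``accuracy_score`` as the
    wrapped function (any callable with a compatible signature works):

    >>> from sklearn.metrics import accuracy_score
    >>> acc = Metric("accuracy", "BINARY_CLASSIFICATION", accuracy_score)
    >>> acc.is_metric("Accuracy")
    True
    >>> acc.standardize_metric_name("Balanced Accuracy")
    'balanced_accuracy'
    >>> float(acc(y_true=[0, 1, 1], y_pred=[0, 1, 0]))  # delegates to accuracy_score
    0.6666666666666666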

    """

    name: str
    metric_category: str
    fun: Optional[Callable[[Any], Any]] = None
    takes_prob: Optional[bool] = False
    equivalent_names: Optional[List[str]] = None

    def __post_init__(self):
        if self.equivalent_names is None:
            self.equivalent_names = {self.name}
        else:
            self.equivalent_names = set(self.equivalent_names + [self.name])
        self.metric_category = self.metric_category.upper()
        if self.metric_category not in METRIC_CATEGORIES:
            raise ValidationError(f"metric type ({self.metric_category}) isn't valid")
        self.humanized_type = humanize_label(self.name)

    def __call__(self, **kwargs):
        # Return the underlying function's result rather than silently dropping it
        return self.fun(**kwargs)

    def get_fun_doc(self):
        if self.fun:
            return self.fun.__doc__

    def print_fun_doc(self):
        print(self.get_fun_doc())

    def is_metric(
        self, metric_name: str, metric_categories: Optional[List[str]] = None
    ):
        metric_name = self.standardize_metric_name(metric_name)
        name_match = False
        if self.equivalent_names:
            name_match = metric_name in self.equivalent_names
        if metric_categories is not None:
            return name_match and self.metric_category in metric_categories
        return name_match

    def standardize_metric_name(self, metric):
        # standardize: lower-case, collapse repeated whitespace, and replace
        # delimiters (spaces, hyphens, underscores) with underscores
        standard = "_".join(
            re.split(r"[-\s_]", re.sub(r"\s\s+", " ", metric.lower()))
        )
        return standard


def metrics_from_dict(
    metric_dict, metric_category, probability_functions, metric_equivalents
):
    # Convert a dictionary of {name: function} pairs to Metric objects
    metrics = {}
    for metric_name, fun in metric_dict.items():
        equivalents = metric_equivalents.get(metric_name, [])  # get equivalent names
        # whether the metric takes probabilities instead of predictions
        takes_prob = metric_name in probability_functions
        metric = Metric(metric_name, metric_category, fun, takes_prob, equivalents)
        metrics[metric_name] = metric
    return metrics


def find_metrics(metric_name, metric_category=None):
    """Find metrics by name and metric category

    Parameters
    ----------
    metric_name : str
        metric name to search for
    metric_category : str or list, optional
        category or list of metric categories to constrain the search to. The list
        of metric categories is stored in modules.constants_metrics.METRIC_CATEGORIES,
        by default None

    Returns
    -------
    list
        list of matching Metrics
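
    Examples
    --------
    A hedged sketch; ``"precision_score"`` is assumed to be one of the registered
    binary classification metric names in constants_metrics:

    >>> matches = find_metrics("precision_score", "BINARY_CLASSIFICATION")  # doctest: +SKIP
    >>> [m.name for m in matches]  # doctest: +SKIP
    ['precision_score']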

    """
    if isinstance(metric_category, str):
        metric_category = [metric_category]
    matched_metrics = [
        i for i in ALL_METRICS if i.is_metric(metric_name, metric_category)
    ]
    return matched_metrics


def find_single_metric(metric_name, metric_category=None):
    """As find_metrics, but enforces the expectation that a single metric is returned"""
    matched_metric = find_metrics(metric_name, metric_category)
    if len(matched_metric) == 1:
        matched_metric = matched_metric[0]
    elif len(matched_metric) == 0:
        raise Exception(
            f"Returned no metrics when searching using the provided metric name <{metric_name}> "
            f"with metric category <{metric_category}>. Expected to find one matching metric."
        )
    else:
        raise Exception(
            f"Returned multiple metrics when searching using the provided metric name <{metric_name}> "
            f"with metric category <{metric_category}>. Expected to find only one matching metric. "
            "Try being more specific with the metric categories passed or using find_metrics if "
            "multiple metrics are desired."
        )
    return matched_metric


def process_metrics(metrics, metric_categories=None):
    """Converts a list of metrics or strings into a standardized form

    The standardized form is a dictionary of str: Metric, where each str represents
    a metric name.

    Parameters
    ----------
    metrics: list
        List of strings or Metrics
    metric_categories: str or list
        One or more metric categories used to constrain string-based metric search
        (see modules.metrics.find_single_metric). The list
        of metric categories is stored in modules.constants_metrics.METRIC_CATEGORIES

    Returns
    -------
    processed_metrics: dict
        Standardized dictionary of metrics. Generally passed to
        evaluators.utils.fairlearn.setup_metric_frames
    fairness_metrics: dict
        Standardized dictionary of fairness metrics. Used for certain evaluator functions
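
    Examples
    --------
    A hedged sketch; the metric names are assumed to be registered in
    constants_metrics (swap in names that exist in your installation):

    >>> performance, fairness = process_metrics(
    ...     ["accuracy_score", "demographic_parity_difference"],
    ...     metric_categories="BINARY_CLASSIFICATION",
    ... )  # doctest: +SKIP
    >>> sorted(performance), sorted(fairness)  # doctest: +SKIP
    (['accuracy_score'], ['demographic_parity_difference'])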

    """
    processed_metrics = {}
    fairness_metrics = {}
    metric_categories_to_include = MODEL_METRIC_CATEGORIES.copy()
    if metric_categories is not None:
        metric_categories_to_include += wrap_list(metric_categories)
    else:
        metric_categories_to_include += MODEL_TYPES

    for metric in metrics:
        if isinstance(metric, str):
            metric_name = metric
            metric = find_single_metric(metric, metric_categories_to_include)
        else:
            metric_name = metric.name
        if not isinstance(metric, Metric):
            raise ValidationError(
                "Specified metric is not of type credoai.metric.Metric"
            )
        if metric.metric_category == "FAIRNESS":
            fairness_metrics[metric_name] = metric
        else:
            processed_metrics[metric_name] = metric
    return processed_metrics, fairness_metrics


# Convert the metric function dictionaries to Metric objects
BINARY_CLASSIFICATION_METRICS = metrics_from_dict(
    BINARY_CLASSIFICATION_FUNCTIONS,
    "BINARY_CLASSIFICATION",
    PROBABILITY_FUNCTIONS,
    METRIC_EQUIVALENTS,
)

MULTICLASS_CLASSIFICATION_METRICS = metrics_from_dict(
    MULTICLASS_CLASSIFICATION_FUNCTIONS,
    "MULTICLASS_CLASSIFICATION",
    PROBABILITY_FUNCTIONS,
    METRIC_EQUIVALENTS,
)

THRESHOLD_VARYING_METRICS = metrics_from_dict(
    THRESHOLD_PROBABILITY_FUNCTIONS,
    "BINARY_CLASSIFICATION_THRESHOLD",
    THRESHOLD_PROBABILITY_FUNCTIONS,
    THRESHOLD_METRIC_EQUIVALENTS,
)

REGRESSION_METRICS = metrics_from_dict(
    REGRESSION_FUNCTIONS, "REGRESSION", PROBABILITY_FUNCTIONS, METRIC_EQUIVALENTS
)

FAIRNESS_METRICS = metrics_from_dict(
    FAIRNESS_FUNCTIONS, "FAIRNESS", PROBABILITY_FUNCTIONS, METRIC_EQUIVALENTS
)

DATASET_METRICS = {m: Metric(m, "DATASET", None, False) for m in DATASET_METRIC_TYPES}

PRIVACY_METRICS = {m: Metric(m, "PRIVACY", None, False) for m in PRIVACY_METRIC_TYPES}

SECURITY_METRICS = {
    m: Metric(m, "SECURITY", None, False) for m in SECURITY_METRIC_TYPES
}


METRIC_NAMES = (
    list(BINARY_CLASSIFICATION_METRICS.keys())
    + list(THRESHOLD_VARYING_METRICS.keys())
    + list(FAIRNESS_METRICS.keys())
    + list(DATASET_METRICS.keys())
    + list(PRIVACY_METRICS.keys())
    + list(SECURITY_METRICS.keys())
    + list(REGRESSION_METRICS.keys())
)

ALL_METRICS = (
    list(BINARY_CLASSIFICATION_METRICS.values())
    + list(MULTICLASS_CLASSIFICATION_METRICS.values())
    + list(THRESHOLD_VARYING_METRICS.values())
    + list(FAIRNESS_METRICS.values())
    + list(DATASET_METRICS.values())
    + list(PRIVACY_METRICS.values())
    + list(SECURITY_METRICS.values())
    + list(REGRESSION_METRICS.values())
)
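

# A minimal usage sketch of the lookup helpers above; it only runs when the module
# is executed directly, not on import. The metric names used here ("accuracy_score",
# "demographic_parity_difference") are assumed to be registered in
# constants_metrics -- swap in names that exist in your installation.
if __name__ == "__main__":
    # Look up a single metric by name, constrained to binary classification
    accuracy = find_single_metric("accuracy_score", "BINARY_CLASSIFICATION")
    print(accuracy.name, accuracy.metric_category)

    # Split a mixed list of metric names into performance and fairness metrics
    performance, fairness = process_metrics(
        ["accuracy_score", "demographic_parity_difference"],
        metric_categories="BINARY_CLASSIFICATION",
    )
    print(list(performance), list(fairness))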