Coverage for credoai/modules/metrics.py: 85%
88 statements
import re
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

from credoai.artifacts.model.constants_model import MODEL_TYPES
from credoai.modules.constants_metrics import *
from credoai.modules.constants_threshold_metrics import *
from credoai.utils.common import ValidationError, humanize_label, wrap_list

@dataclass
class Metric:
    """Class to define metrics

    Metric categories determine what kind of use the metric is designed for. Credo AI assumes
    that metric signatures correspond with either scikit-learn or fairlearn method signatures
    in the case of binary/multiclass classification, regression, clustering and fairness metrics.

    Dataset metrics are used as documentation placeholders and define no function.
    See DATASET_METRICS for examples. CUSTOM metrics have no expectations and will
    not be automatically used by Lens modules.

    Metric Categories:

    * {BINARY|MULTICLASS}_CLASSIFICATION: metrics like `scikit-learn's classification metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * REGRESSION: metrics like `scikit-learn's regression metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * CLUSTERING: metrics like `scikit-learn's clustering metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * FAIRNESS: metrics like `fairlearn's equalized odds metric <https://fairlearn.org/v0.5.0/api_reference/fairlearn.metrics.html>`_
    * DATASET: metrics intended as documentation placeholders; they define no function (see DATASET_METRICS)
    * CUSTOM: no expectations for the metric's function; not used automatically by Lens modules

    Parameters
    ----------
    name : str
        The primary name of the metric
    metric_category : str
        Must be one of the METRIC_CATEGORIES listed above
    fun : callable, optional
        The function definition of the metric. If None, the metric cannot be called and is only
        defined for documentation purposes
    takes_prob : bool, optional
        Whether the function takes decision probabilities
        instead of predicted classes, as for ROC AUC. Similar to `needs_proba` used by
        `sklearn <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.make_scorer.html>`_,
        by default False
    equivalent_names : list, optional
        List of other names for the metric
    """

    name: str
    metric_category: str
    fun: Optional[Callable[[Any], Any]] = None
    takes_prob: Optional[bool] = False
    equivalent_names: Optional[List[str]] = None

    def __post_init__(self):
        if self.equivalent_names is None:
            self.equivalent_names = {self.name}
        else:
            self.equivalent_names = set(self.equivalent_names + [self.name])
        self.metric_category = self.metric_category.upper()
        if self.metric_category not in METRIC_CATEGORIES:
            raise ValidationError(f"metric type ({self.metric_category}) isn't valid")
        self.humanized_type = humanize_label(self.name)

    def __call__(self, **kwargs):
        # return the metric value rather than discarding it
        return self.fun(**kwargs)

    def get_fun_doc(self):
        if self.fun:
            return self.fun.__doc__

    def print_fun_doc(self):
        print(self.get_fun_doc())

    def is_metric(
        self, metric_name: str, metric_categories: Optional[List[str]] = None
    ):
        metric_name = self.standardize_metric_name(metric_name)
        # default to False so name_match is always bound, even if equivalent_names is empty
        name_match = False
        if self.equivalent_names:
            name_match = metric_name in self.equivalent_names
        if metric_categories is not None:
            return name_match and self.metric_category in metric_categories
        return name_match

    def standardize_metric_name(self, metric):
        # standardize: lowercase, collapse repeated whitespace, replace delimiters with underscores
        standard = "_".join(re.split(r"[-\s_]", re.sub(r"\s\s+", " ", metric.lower())))
        return standard
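
# Illustrative sketch (not part of the original module): constructing and querying a
# Metric by hand. Assumes scikit-learn is installed and that "BINARY_CLASSIFICATION"
# is one of METRIC_CATEGORIES; the names "f1_example" and "f1_demo" are hypothetical.
def _example_metric_usage():
    from sklearn.metrics import f1_score

    f1_metric = Metric(
        name="f1_example",
        metric_category="binary_classification",  # upper-cased in __post_init__
        fun=f1_score,
        equivalent_names=["f1_demo"],
    )
    assert f1_metric.is_metric("F1 Demo")  # names are standardized before matching
    return f1_metric(y_true=[0, 1, 1], y_pred=[0, 1, 0])  # delegates to f1_score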

def metrics_from_dict(dict, metric_category, probability_functions, metric_equivalents):
    # Convert a dictionary of metric functions to Metric objects
    metrics = {}
    for metric_name, fun in dict.items():
        equivalents = metric_equivalents.get(metric_name, [])  # get equivalent names
        # whether the metric takes probabilities instead of predictions
        takes_prob = metric_name in probability_functions
        metric = Metric(metric_name, metric_category, fun, takes_prob, equivalents)
        metrics[metric_name] = metric
    return metrics
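
# Illustrative sketch (not part of the original module): building Metric objects from a
# toy function dictionary. Assumes scikit-learn is installed and that "REGRESSION" is one
# of METRIC_CATEGORIES; the dictionary contents are hypothetical.
def _example_metrics_from_dict():
    from sklearn.metrics import mean_absolute_error, r2_score

    toy_functions = {"r2_score": r2_score, "mean_absolute_error": mean_absolute_error}
    toy_equivalents = {"r2_score": ["r2"]}
    # neither toy metric consumes probabilities, so pass an empty set
    return metrics_from_dict(toy_functions, "REGRESSION", set(), toy_equivalents)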

def find_metrics(metric_name, metric_category=None):
    """Find metric by name and metric category

    Parameters
    ----------
    metric_name : str
        metric name to search for
    metric_category : str or list, optional
        category or list of metric categories to constrain search to. The list
        of metric categories is stored in modules.constants_metrics.METRIC_CATEGORIES,
        by default None

    Returns
    -------
    list
        list of Metrics
    """
    if isinstance(metric_category, str):
        metric_category = [metric_category]
    matched_metrics = [
        i for i in ALL_METRICS if i.is_metric(metric_name, metric_category)
    ]
    return matched_metrics

def find_single_metric(metric_name, metric_category=None):
    """Like find_metrics, but enforces that exactly one metric is returned"""
    matched_metric = find_metrics(metric_name, metric_category)
    if len(matched_metric) == 1:
        matched_metric = matched_metric[0]
    elif len(matched_metric) == 0:
        raise Exception(
            f"Returned no metrics when searching with the metric name <{metric_name}> "
            f"and metric category <{metric_category}>. Expected to find one matching metric."
        )
    else:
        raise Exception(
            f"Returned multiple metrics when searching with the metric name <{metric_name}> "
            f"and metric category <{metric_category}>. Expected to find only one matching metric. "
            "Try being more specific with the metric categories passed, or use find_metrics if "
            "multiple metrics are desired."
        )
    return matched_metric
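
# Illustrative sketch (not part of the original module): looking up registered metrics by
# name. Assumes "precision_score" is a registered binary-classification metric name; any
# entry of METRIC_NAMES works equally well here.
def _example_find_metrics():
    # every Metric whose equivalent names include "precision_score"
    matches = find_metrics("precision_score")
    # constrain by category and require exactly one match
    single = find_single_metric("precision_score", "BINARY_CLASSIFICATION")
    return matches, single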

def process_metrics(metrics, metric_categories=None):
    """Converts a list of metrics or strings into a standardized form

    The standardized form is a dictionary of str: Metric, where each str represents
    a metric name.

    Parameters
    ----------
    metrics : list
        List of strings or Metrics
    metric_categories : str or list
        One or more metric categories used to constrain string-based metric search
        (see modules.metrics.find_single_metric). The list
        of metric categories is stored in modules.constants_metrics.METRIC_CATEGORIES

    Returns
    -------
    processed_metrics : dict
        Standardized dictionary of metrics. Generally used to pass to
        evaluators.utils.fairlearn.setup_metric_frames
    fairness_metrics : dict
        Standardized dictionary of fairness metrics. Used for certain evaluator functions
    """
    processed_metrics = {}
    fairness_metrics = {}
    metric_categories_to_include = MODEL_METRIC_CATEGORIES.copy()
    if metric_categories is not None:
        metric_categories_to_include += wrap_list(metric_categories)
    else:
        metric_categories_to_include += MODEL_TYPES

    for metric in metrics:
        if isinstance(metric, str):
            metric_name = metric
            metric = find_single_metric(metric, metric_categories_to_include)
        else:
            metric_name = metric.name
        if not isinstance(metric, Metric):
            raise ValidationError(
                "Specified metric is not of type credoai.modules.metrics.Metric"
            )
        if metric.metric_category == "FAIRNESS":
            fairness_metrics[metric_name] = metric
        else:
            processed_metrics[metric_name] = metric
    return processed_metrics, fairness_metrics
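
# Illustrative sketch (not part of the original module): standardizing a mixed metric
# specification. Assumes "accuracy_score" and "demographic_parity_difference" are
# registered metric names and that fairness metrics are covered by MODEL_METRIC_CATEGORIES;
# the split below is driven by each Metric's metric_category.
def _example_process_metrics():
    performance, fairness = process_metrics(
        ["accuracy_score", "demographic_parity_difference"],
        metric_categories="BINARY_CLASSIFICATION",
    )
    return performance, fairness  # fairness holds only the FAIRNESS-category metrics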

# Convert function dictionaries to Metric objects
BINARY_CLASSIFICATION_METRICS = metrics_from_dict(
    BINARY_CLASSIFICATION_FUNCTIONS,
    "BINARY_CLASSIFICATION",
    PROBABILITY_FUNCTIONS,
    METRIC_EQUIVALENTS,
)

MULTICLASS_CLASSIFICATION_METRICS = metrics_from_dict(
    MULTICLASS_CLASSIFICATION_FUNCTIONS,
    "MULTICLASS_CLASSIFICATION",
    PROBABILITY_FUNCTIONS,
    METRIC_EQUIVALENTS,
)

THRESHOLD_VARYING_METRICS = metrics_from_dict(
    THRESHOLD_PROBABILITY_FUNCTIONS,
    "BINARY_CLASSIFICATION_THRESHOLD",
    THRESHOLD_PROBABILITY_FUNCTIONS,
    THRESHOLD_METRIC_EQUIVALENTS,
)

REGRESSION_METRICS = metrics_from_dict(
    REGRESSION_FUNCTIONS, "REGRESSION", PROBABILITY_FUNCTIONS, METRIC_EQUIVALENTS
)

FAIRNESS_METRICS = metrics_from_dict(
    FAIRNESS_FUNCTIONS, "FAIRNESS", PROBABILITY_FUNCTIONS, METRIC_EQUIVALENTS
)

DATASET_METRICS = {m: Metric(m, "DATASET", None, False) for m in DATASET_METRIC_TYPES}

PRIVACY_METRICS = {m: Metric(m, "PRIVACY", None, False) for m in PRIVACY_METRIC_TYPES}

SECURITY_METRICS = {
    m: Metric(m, "SECURITY", None, False) for m in SECURITY_METRIC_TYPES
}

METRIC_NAMES = (
    list(BINARY_CLASSIFICATION_METRICS.keys())
    + list(THRESHOLD_VARYING_METRICS.keys())
    + list(FAIRNESS_METRICS.keys())
    + list(DATASET_METRICS.keys())
    + list(PRIVACY_METRICS.keys())
    + list(SECURITY_METRICS.keys())
    + list(REGRESSION_METRICS.keys())
)

ALL_METRICS = (
    list(BINARY_CLASSIFICATION_METRICS.values())
    + list(MULTICLASS_CLASSIFICATION_METRICS.values())
    + list(THRESHOLD_VARYING_METRICS.values())
    + list(FAIRNESS_METRICS.values())
    + list(DATASET_METRICS.values())
    + list(PRIVACY_METRICS.values())
    + list(SECURITY_METRICS.values())
    + list(REGRESSION_METRICS.values())
)
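
# Illustrative sketch (not part of the original module): the registries above can be
# inspected directly, e.g. to list every registered fairness metric name.
def _example_list_fairness_metric_names():
    return sorted(m.name for m in ALL_METRICS if m.metric_category == "FAIRNESS")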