Coverage for credoai/modules/metrics.py: 88%
60 statements
coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
import re
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

from credoai.modules.constants_metrics import *
from credoai.modules.constants_threshold_metrics import *
from credoai.utils.common import ValidationError, humanize_label


@dataclass
class Metric:
    """Class to define metrics

    Metric categories determine what kind of use the metric is designed for. Credo AI assumes
    that metric signatures correspond with either scikit-learn or fairlearn method signatures,
    in the case of binary/multiclass classification, regression, clustering and fairness metrics.

    Dataset metrics are used as documentation placeholders and define no function.
    See DATASET_METRICS for examples. CUSTOM metrics have no expectations and will
    not be automatically used by Lens modules.

    Metric Categories:

    * {BINARY|MULTICLASS}_CLASSIFICATION: metrics like `scikit-learn's classification metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * REGRESSION: metrics like `scikit-learn's regression metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * CLUSTERING: metrics like `scikit-learn's clustering metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
    * FAIRNESS: metrics like `fairlearn's equalized odds metric <https://fairlearn.org/v0.5.0/api_reference/fairlearn.metrics.html>`_
    * DATASET: metrics intended as documentation placeholders for datasets; they define no function
    * CUSTOM: metrics with no expectations; not automatically used by Lens modules

    Parameters
    ----------
    name : str
        The primary name of the metric
    metric_category : str
        Defined to be one of the METRIC_CATEGORIES, above
    fun : callable, optional
        The function definition of the metric. If None, the metric cannot be used and is only
        defined for documentation purposes
    takes_prob : bool, optional
        Whether the function takes the decision probabilities
        instead of the predicted class, as for ROC AUC. Similar to `needs_proba` used by
        `sklearn <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.make_scorer.html>`_,
        by default False
    equivalent_names : list
        List of other names for the metric
    """

    name: str
    metric_category: str
    fun: Optional[Callable[[Any], Any]] = None
    takes_prob: Optional[bool] = False
    equivalent_names: Optional[List[str]] = None

    def __post_init__(self):
        if self.equivalent_names is None:
            self.equivalent_names = {self.name}
        else:
            self.equivalent_names = set(self.equivalent_names + [self.name])
        self.metric_category = self.metric_category.upper()
        if self.metric_category not in METRIC_CATEGORIES:
            raise ValidationError(f"metric type ({self.metric_category}) isn't valid")
        self.humanized_type = humanize_label(self.name)

    def __call__(self, **kwargs):
        return self.fun(**kwargs)

    def get_fun_doc(self):
        if self.fun:
            return self.fun.__doc__

    def print_fun_doc(self):
        print(self.get_fun_doc())

    def is_metric(
        self, metric_name: str, metric_categories: Optional[List[str]] = None
    ):
        metric_name = self.standardize_metric_name(metric_name)
        if self.equivalent_names:
            name_match = metric_name in self.equivalent_names
        if metric_categories is not None:
            return name_match and self.metric_category in metric_categories
        return name_match

    def standardize_metric_name(self, metric):
        # standardize: lower case, collapse repeated whitespace,
        # replace delimiters (space, dash, underscore) with underscores
        standard = "_".join(re.split(r"[- \s _]", re.sub(r"\s\s+", " ", metric.lower())))
        return standard
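

# Illustrative usage sketch (not part of the original module): wrapping an existing
# scikit-learn function in a Metric and querying it by an equivalent name. Assumes
# "BINARY_CLASSIFICATION" is one of the METRIC_CATEGORIES, as it is used below.
#
#     from sklearn.metrics import accuracy_score
#
#     accuracy = Metric(
#         name="accuracy_score",
#         metric_category="binary_classification",  # upper-cased in __post_init__
#         fun=accuracy_score,
#         takes_prob=False,
#         equivalent_names=["accuracy"],
#     )
#     accuracy(y_true=[0, 1, 1], y_pred=[0, 1, 0])    # forwards kwargs to accuracy_score
#     accuracy.is_metric("Accuracy")                  # True: names are standardized
#     accuracy.is_metric("accuracy", ["REGRESSION"])  # False: category does not match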


def metrics_from_dict(dict, metric_category, probability_functions, metric_equivalents):
    # Convert to metric objects
    metrics = {}
    for metric_name, fun in dict.items():
        equivalents = metric_equivalents.get(metric_name, [])  # get equivalent names
        # whether the metric takes probabilities instead of predictions
        takes_prob = metric_name in probability_functions
        metric = Metric(metric_name, metric_category, fun, takes_prob, equivalents)
        metrics[metric_name] = metric
    return metrics
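

# Illustrative sketch (not part of the original module): metrics_from_dict turns a plain
# {name: function} mapping into a {name: Metric} mapping. The example names and helper
# dictionaries below are made up for illustration.
#
#     from sklearn.metrics import precision_score, roc_auc_score
#
#     custom_metrics = metrics_from_dict(
#         {"precision_score": precision_score, "roc_auc": roc_auc_score},
#         "BINARY_CLASSIFICATION",
#         probability_functions={"roc_auc"},                    # roc_auc consumes probabilities
#         metric_equivalents={"precision_score": ["precision"]},
#     )
#     custom_metrics["roc_auc"].takes_prob                      # True
#     custom_metrics["precision_score"].is_metric("precision")  # True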


def find_metrics(metric_name, metric_category=None):
    """Find metrics by name and metric category

    Parameters
    ----------
    metric_name : str
        metric name to search for
    metric_category : str or list, optional
        category or list of categories to constrain search to, by default None

    Returns
    -------
    list
        list of Metrics
    """
    if isinstance(metric_category, str):
        metric_category = [metric_category]
    matched_metrics = [
        i for i in ALL_METRICS if i.is_metric(metric_name, metric_category)
    ]
    return matched_metrics
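

# Illustrative sketch (not part of the original module): looking metrics up by name,
# optionally constrained to one or more categories. The exact matches depend on the
# metric constants imported above.
#
#     find_metrics("precision")                       # every metric answering to "precision"
#     find_metrics("equalized_odds", "FAIRNESS")      # constrain the search to fairness metrics
#     find_metrics("accuracy_score", ["BINARY_CLASSIFICATION", "MULTICLASS_CLASSIFICATION"])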


# Convert To List of Metrics
BINARY_CLASSIFICATION_METRICS = metrics_from_dict(
    BINARY_CLASSIFICATION_FUNCTIONS,
    "BINARY_CLASSIFICATION",
    PROBABILITY_FUNCTIONS,
    METRIC_EQUIVALENTS,
)

THRESHOLD_VARYING_METRICS = metrics_from_dict(
    THRESHOLD_PROBABILITY_FUNCTIONS,
    "BINARY_CLASSIFICATION_THRESHOLD",
    THRESHOLD_PROBABILITY_FUNCTIONS,
    THRESHOLD_METRIC_EQUIVALENTS,
)

REGRESSION_METRICS = metrics_from_dict(
    REGRESSION_FUNCTIONS, "REGRESSION", PROBABILITY_FUNCTIONS, METRIC_EQUIVALENTS
)

FAIRNESS_METRICS = metrics_from_dict(
    FAIRNESS_FUNCTIONS, "FAIRNESS", PROBABILITY_FUNCTIONS, METRIC_EQUIVALENTS
)

DATASET_METRICS = {m: Metric(m, "DATASET", None, False) for m in DATASET_METRIC_TYPES}

PRIVACY_METRICS = {m: Metric(m, "PRIVACY", None, False) for m in PRIVACY_METRIC_TYPES}

SECURITY_METRICS = {
    m: Metric(m, "SECURITY", None, False) for m in SECURITY_METRIC_TYPES
}

METRIC_NAMES = (
    list(BINARY_CLASSIFICATION_METRICS.keys())
    + list(THRESHOLD_VARYING_METRICS.keys())
    + list(FAIRNESS_METRICS.keys())
    + list(DATASET_METRICS.keys())
    + list(PRIVACY_METRICS.keys())
    + list(SECURITY_METRICS.keys())
    + list(REGRESSION_METRICS.keys())
)

ALL_METRICS = (
    list(BINARY_CLASSIFICATION_METRICS.values())
    + list(THRESHOLD_VARYING_METRICS.values())
    + list(FAIRNESS_METRICS.values())
    + list(DATASET_METRICS.values())
    + list(PRIVACY_METRICS.values())
    + list(SECURITY_METRICS.values())
    + list(REGRESSION_METRICS.values())
)