Coverage for credoai/evaluators/fairness.py: 92%
145 statements

from collections import defaultdict
from typing import List

import pandas as pd
from connect.evidence import MetricContainer, TableContainer

from credoai.artifacts import TabularData
from credoai.evaluators import Evaluator
from credoai.evaluators.utils.fairlearn import setup_metric_frames
from credoai.evaluators.utils.validation import (
    check_artifact_for_nulls,
    check_data_instance,
    check_existence,
)
from credoai.modules.constants_metrics import (
    MODEL_METRIC_CATEGORIES,
    THRESHOLD_METRIC_CATEGORIES,
)
from credoai.modules.metrics import Metric, find_metrics
from credoai.utils.common import ValidationError


class ModelFairness(Evaluator):
    """
    Model Fairness evaluator for Credo AI.

    This evaluator calculates performance metrics disaggregated by a sensitive feature, as
    well as evaluating the parity of those metrics.

    Handles any metric that can be calculated on a set of ground truth labels and predictions,
    e.g., binary classification, multiclass classification, regression.

    Parameters
    ----------
    metrics : List-like
        List of metric names as strings or list of Metrics (credoai.metrics.Metric).
        Metric strings should be in the list returned by credoai.modules.list_metrics.
        Note: for performance parity metrics like "false negative rate parity",
        just list "false negative rate"; parity metrics are calculated automatically
        if the performance metric is supplied.
    method : str, optional
        How to compute the differences: "between_groups" or "to_overall".
        See fairlearn.metrics.MetricFrame.difference for details,
        by default "between_groups".
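
    Examples
    --------
    A minimal construction sketch (metric names below are illustrative and must
    be in the list returned by credoai.modules.list_metrics)::

        evaluator = ModelFairness(
            metrics=["false negative rate", "precision_score"],
            method="between_groups",
        )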
    """

    required_artifacts = {"model", "data", "sensitive_feature"}

    def __init__(
        self,
        metrics=None,
        method="between_groups",
    ):
        self.metrics = metrics
        self.fairness_method = method
        self.fairness_metrics = None
        self.fairness_prob_metrics = None
        super().__init__()

    def _validate_arguments(self):
        check_existence(self.metrics, "metrics")
        check_data_instance(self.data, TabularData)
        check_existence(self.data.sensitive_features, "sensitive_features")
        check_artifact_for_nulls(self.data, "Data")

    def _setup(self):
        self.sensitive_features = self.data.sensitive_feature
        self.y_true = self.data.y
        self.y_pred = self.model.predict(self.data.X)
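        # Probability outputs are optional; they are only used by metrics that
        # consume probabilities (probability, threshold, and some fairness metrics).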
        if hasattr(self.model, "predict_proba"):
            self.y_prob = self.model.predict_proba(self.data.X)
        else:
            self.y_prob = None
        self.update_metrics(self.metrics)
        self.sens_feat_label = {"sensitive_feature": self.sensitive_features.name}

    def evaluate(self):
        """
        Run the fairness evaluation and store the results on the evaluator.
        """
        fairness_results = self.get_fairness_results()
        disaggregated_metrics = self.get_disaggregated_performance()
        disaggregated_thresh_results = self.get_disaggregated_threshold_performance()

        results = []
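        # Each result may be a single evidence container or a list of containers;
        # flatten everything into one list.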
        for result_obj in [
            fairness_results,
            disaggregated_metrics,
            disaggregated_thresh_results,
        ]:
            if result_obj is not None:
                try:
                    results += result_obj
                except TypeError:
                    results.append(result_obj)

        self.results = results
        return self

    def update_metrics(self, metrics, replace=True):
        """
        Replace or extend the evaluator's metrics and rebuild the metric frames.

        Parameters
        ----------
        metrics : List-like
            List of metric names as strings or list of Metrics (credoai.metrics.Metric).
            Metric strings should be in the list returned by credoai.modules.list_metrics.
            Note: for performance parity metrics like "false negative rate parity",
            just list "false negative rate"; parity metrics are calculated automatically
            if the performance metric is supplied.
        replace : bool, optional
            If True, replace the current metrics; if False, append to them,
            by default True.
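
        Examples
        --------
        A hedged sketch of appending an additional metric after the evaluator
        has been set up (the metric name is illustrative)::

            evaluator.update_metrics(["precision_score"], replace=False)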
        """
        if replace:
            self.metrics = metrics
        else:
            self.metrics += metrics
        (
            self.performance_metrics,
            self.prob_metrics,
            self.threshold_metrics,
            self.fairness_metrics,
            self.failed_metrics,
        ) = self._process_metrics(self.metrics)
        self.metric_frames = setup_metric_frames(
            self.performance_metrics,
            self.prob_metrics,
            self.threshold_metrics,
            self.y_pred,
            self.y_prob,
            self.y_true,
            self.sensitive_features,
        )

    def get_disaggregated_performance(self):
        """
        Return performance metrics disaggregated by sensitive-feature group.

        Returns
        -------
        TableContainer
            The disaggregated performance metrics in long form, or None if no
            metrics could be calculated.
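
        Examples
        --------
        A sketch of the expected long-form layout (group names and values are
        illustrative)::

            gender  type             value
            female  precision_score  0.92
            male    precision_score  0.88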
        """
        disaggregated_df = pd.DataFrame()
        for name, metric_frame in self.metric_frames.items():
            if name == "thresh":
                continue
            df = metric_frame.by_group.copy().convert_dtypes()
            disaggregated_df = pd.concat([disaggregated_df, df], axis=1)

        if disaggregated_df.empty:
            self.logger.warning("No disaggregated metrics could be calculated.")
            return

        # reshape
        disaggregated_results = disaggregated_df.reset_index().melt(
            id_vars=[disaggregated_df.index.name],
            var_name="type",
        )
        disaggregated_results.name = "disaggregated_performance"

        metric_type_label = {
            "metric_types": disaggregated_results.type.unique().tolist()
        }

        return TableContainer(
            disaggregated_results,
            **self.get_container_info(
                labels={**self.sens_feat_label, **metric_type_label}
            ),
        )

    def get_disaggregated_threshold_performance(self):
        """
        Return threshold-dependent performance metrics disaggregated by
        sensitive-feature group.

        Returns
        -------
        List[TableContainer]
            One table per threshold-dependent metric type, or None if no
            threshold metrics were calculated.
        """
        metric_frame = self.metric_frames.get("thresh")
        if metric_frame is None:
            return
        df = metric_frame.by_group.copy().convert_dtypes()

        df = df.reset_index().melt(
            id_vars=[df.index.name],
            var_name="type",
        )
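
        # After the melt, each row's "value" holds a per-group DataFrame of
        # threshold-varying results; group those frames by metric type and
        # concatenate them into one table per type.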
        to_return = defaultdict(list)
        for i, row in df.iterrows():
            tmp_df = row["value"]
            tmp_df = tmp_df.assign(**row.drop("value"))
            to_return[row["type"]].append(tmp_df)
        for key in to_return.keys():
            df = pd.concat(to_return[key])
            df.name = "threshold_dependent_disaggregated_performance"
            to_return[key] = df

        disaggregated_thresh_results = []
        for key, df in to_return.items():
            labels = {**self.sens_feat_label, **{"metric_type": key}}
            disaggregated_thresh_results.append(
                TableContainer(df, **self.get_container_info(labels=labels))
            )

        return disaggregated_thresh_results

    def get_fairness_results(self):
        """
        Return fairness and performance parity metrics.

        Note: performance parity metrics are labeled with their related
        performance metric, but are computed using
        fairlearn.metrics.MetricFrame.difference(method).

        Returns
        -------
        MetricContainer
            The fairness and parity metrics, or None if none were calculated.
        """

        results = []
        for metric_name, metric in self.fairness_metrics.items():
            pred_argument = {"y_pred": self.y_pred}
            if metric.takes_prob:
                pred_argument = {"y_prob": self.y_prob}
            try:
                metric_value = metric.fun(
                    y_true=self.y_true,
                    sensitive_features=self.sensitive_features,
                    method=self.fairness_method,
                    **pred_argument,
                )
            except Exception as e:
                self.logger.error(
                    f"A metric ({metric_name}) failed to run. "
                    "Are you sure it works with this kind of model and target?\n"
                )
                raise e
            results.append({"metric_type": metric_name, "value": metric_value})

        results = pd.DataFrame.from_dict(results)

        # add parity results
        parity_results = []
        for name, metric_frame in self.metric_frames.items():
            if name == "thresh":
                # Don't calculate difference for curve metrics; it is not mathematically well-defined.
                continue
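            # MetricFrame.difference collapses the per-group results into one
            # disparity value per metric; the rename appends "_parity" to each
            # metric name.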
            diffs = metric_frame.difference(self.fairness_method).rename(
                "{}_parity".format
            )
            diffs = pd.DataFrame({"metric_type": diffs.index, "value": diffs.values})
            parity_results.append(diffs)

        if parity_results:
            parity_results = pd.concat(parity_results)
            results = pd.concat([results, parity_results])

        results.rename({"metric_type": "type"}, axis=1, inplace=True)

        if results.empty:
            self.logger.info("No fairness metrics calculated.")
            return
        return MetricContainer(
            results,
            **self.get_container_info(labels=self.sens_feat_label),
        )

    def _process_metrics(self, metrics):
        """
        Separate metrics into categories.

        Parameters
        ----------
        metrics : List[Union[Metric, str]]
            List of metrics to use. These can be Metric objects (see credoai.modules.metrics.py)
            or strings. Strings are converted to Metric objects using find_metrics.

        Returns
        -------
        tuple
            Dictionaries of performance, probability, threshold, and fairness
            metrics keyed by name, plus a list of metric names that could not
            be categorized.
        """
        # separate metrics
        failed_metrics = []
        performance_metrics = {}
        prob_metrics = {}
        threshold_metrics = {}
        fairness_metrics = {}
        fairness_prob_metrics = {}
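        # Strings are resolved with find_metrics (restricted to MODEL_METRIC_CATEGORIES)
        # and must match exactly one known metric.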
        for metric in metrics:
            if isinstance(metric, str):
                metric_name = metric
                metric = find_metrics(metric, MODEL_METRIC_CATEGORIES)
                if len(metric) == 1:
                    metric = metric[0]
                elif len(metric) == 0:
                    raise Exception(
                        f"No metric found when searching for the provided metric name <{metric_name}>. Expected to find one matching metric."
                    )
                else:
                    raise Exception(
                        f"Multiple metrics found when searching for the provided metric name <{metric_name}>. Expected to find only one matching metric."
                    )
            else:
                metric_name = metric.name
            if not isinstance(metric, Metric):
                raise ValidationError("Metric is not of type credoai.metric.Metric")
            if metric.metric_category == "FAIRNESS":
                fairness_metrics[metric_name] = metric
            elif metric.metric_category in MODEL_METRIC_CATEGORIES:
                if metric.takes_prob:
                    if metric.metric_category in THRESHOLD_METRIC_CATEGORIES:
                        threshold_metrics[metric_name] = metric
                    else:
                        prob_metrics[metric_name] = metric
                else:
                    performance_metrics[metric_name] = metric
            else:
                self.logger.warning(
                    f"{metric_name} could not be used by ModelFairness"
                )
                failed_metrics.append(metric_name)

        return (
            performance_metrics,
            prob_metrics,
            threshold_metrics,
            fairness_metrics,
            failed_metrics,
        )