Coverage for credoai/evaluators/fairness.py: 98%
125 statements
coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

from collections import defaultdict
from typing import Optional

import pandas as pd
from connect.evidence import MetricContainer, TableContainer

from credoai.artifacts import ClassificationModel
from credoai.evaluators.evaluator import Evaluator
from credoai.evaluators.performance import create_confusion_matrix
from credoai.evaluators.utils.fairlearn import setup_metric_frames
from credoai.evaluators.utils.validation import check_data_for_nulls, check_existence
from credoai.modules.metrics import process_metrics


class ModelFairness(Evaluator):
    """
    Model Fairness evaluator for Credo AI.

    This evaluator calculates performance metrics disaggregated by a sensitive feature,
    as well as the parity of those metrics across groups.

    Handles any metric that can be calculated on a set of ground truth labels and
    predictions, e.g., binary classification, multiclass classification, regression.

    Parameters
    ----------
    metrics : List-like
        list of metric names as strings or list of Metrics (credoai.metrics.Metric).
        Metric strings should be in the list returned by credoai.modules.list_metrics.
        Note: for performance parity metrics like "false negative rate parity", just
        list "false negative rate"; parity metrics are calculated automatically when
        the corresponding performance metric is supplied.
    method : str, optional
        How to compute the differences: "between_groups" or "to_overall".
        See fairlearn.metrics.MetricFrame.difference for details, by default
        'between_groups'.
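
    Example
    -------
    A minimal construction sketch (the metric name is illustrative; metric strings
    should come from the list returned by credoai.modules.list_metrics)::

        evaluator = ModelFairness(metrics=["false negative rate"], method="between_groups")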
38 """
40 required_artifacts = {"model", "data", "sensitive_feature"}

    def __init__(
        self,
        metrics=None,
        method="between_groups",
    ):
        self.metrics = metrics
        self.fairness_method = method
        self.fairness_metrics = None
        self.fairness_prob_metrics = None
        super().__init__()

    def _validate_arguments(self):
        check_existence(self.metrics, "metrics")
        check_existence(self.data.y, "y")
        check_data_for_nulls(
            self.data, "Data", check_X=True, check_y=True, check_sens=True
        )
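
    # _setup caches the ground truth, predictions, and (when the model exposes
    # predict_proba) predicted probabilities, then builds the fairlearn metric
    # frames via update_metrics.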
    def _setup(self):
        self.sensitive_features = self.data.sensitive_feature
        self.y_true = self.data.y
        self.y_pred = self.model.predict(self.data.X)
        if hasattr(self.model, "predict_proba"):
            self.y_prob = self.model.predict_proba(self.data.X)
        else:
            self.y_prob = (None,)
        self.update_metrics(self.metrics)
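
    # evaluate collects four result families -- fairness/parity metrics, disaggregated
    # performance, threshold-dependent disaggregated performance, and per-group
    # confusion matrices -- skipping any family that returns None.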
    def evaluate(self):
        """
        Run fairness base module.
        """
        fairness_results = self.get_fairness_results()
        disaggregated_metrics = self.get_disaggregated_performance()
        disaggregated_thresh_results = self.get_disaggregated_threshold_performance()
        confusion_matrix = self.get_confusion_matrix()

        results = []
        for result_obj in [
            fairness_results,
            disaggregated_metrics,
            disaggregated_thresh_results,
            confusion_matrix,
        ]:
            if result_obj is not None:
                try:
                    results += result_obj
                except TypeError:
                    results.append(result_obj)

        self.results = results
        return self

    def update_metrics(self, metrics, replace=True):
        """
        Replace (or, if replace=False, append to) the evaluator's metrics.

        Parameters
        ----------
        metrics : List-like
            list of metric names as strings or list of Metrics (credoai.metrics.Metric).
            Metric strings should be in the list returned by credoai.modules.list_metrics.
            Note: for performance parity metrics like "false negative rate parity", just
            list "false negative rate"; parity metrics are calculated automatically when
            the corresponding performance metric is supplied.
        replace : bool, optional
            If True, replace the existing metrics; if False, append to them. By default True.
        """
        if replace:
            self.metrics = metrics
        else:
            self.metrics += metrics
        self.processed_metrics, self.fairness_metrics = process_metrics(
            self.metrics, self.model.type
        )
        self.metric_frames = setup_metric_frames(
            self.processed_metrics,
            self.y_pred,
            self.y_prob,
            self.y_true,
            self.sensitive_features,
        )
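        # metric_frames is keyed by frame type; the "thresh" frame holds
        # threshold-dependent (curve) metrics and is handled separately from the
        # scalar frames in the methods below.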

    def get_confusion_matrix(self) -> Optional[TableContainer]:
        """
        Create a confusion matrix if the model is a classification model.

        This returns a confusion matrix for each subgroup within a sensitive feature.

        Returns
        -------
        Optional[TableContainer]
            Table container containing the confusion matrix. A single table is created
            in which one of the columns (sens_feat_group) contains the label identifying
            the sensitive feature subgroup.
        """
        if not isinstance(self.model, ClassificationModel):
            return None

        df = pd.DataFrame(
            {
                "y_true": self.y_true,
                "y_pred": self.y_pred,
                "sens_feat": self.sensitive_features,
            }
        )

        cm_disag = []
        for group in df.groupby("sens_feat"):
            cm = create_confusion_matrix(group[1].y_true, group[1].y_pred)
            cm["sens_feat_group"] = group[0]
            cm_disag.append(cm)

        cm_disag = pd.concat(cm_disag, ignore_index=True)
        cm_disag.name = "disaggregated_confusion_matrix"

        return TableContainer(cm_disag, **self.get_info())
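
    # get_disaggregated_performance concatenates per-group values from every
    # non-threshold MetricFrame and melts them into a long-form table
    # (one row per group/metric combination).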
    def get_disaggregated_performance(self):
        """
        Return performance metrics for each group.

        Returns
        -------
        TableContainer
            The disaggregated performance metrics
        """
        disaggregated_df = pd.DataFrame()
        for name, metric_frame in self.metric_frames.items():
            if name == "thresh":
                continue
            df = metric_frame.by_group.copy().convert_dtypes()
            disaggregated_df = pd.concat([disaggregated_df, df], axis=1)

        if disaggregated_df.empty:
            self.logger.warning("No disaggregated metrics could be calculated.")
            return

        # reshape into long form: one row per (group, metric type)
        disaggregated_results = disaggregated_df.reset_index().melt(
            id_vars=[disaggregated_df.index.name],
            var_name="type",
        )
        disaggregated_results.name = "disaggregated_performance"

        metric_type_label = {
            "metric_types": disaggregated_results.type.unique().tolist()
        }

        return TableContainer(
            disaggregated_results,
            **self.get_info(labels=metric_type_label),
        )
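
    # For threshold-dependent (curve) metrics, each MetricFrame cell already holds a
    # DataFrame of values across thresholds; the method below expands those cells and
    # returns one TableContainer per threshold-dependent metric.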
    def get_disaggregated_threshold_performance(self):
        """
        Return threshold-dependent performance metrics for each group.

        Returns
        -------
        List[TableContainer]
            The disaggregated performance metrics
        """
        metric_frame = self.metric_frames.get("thresh")
        if metric_frame is None:
            return
        df = metric_frame.by_group.copy().convert_dtypes()

        df = df.reset_index().melt(
            id_vars=[df.index.name],
            var_name="type",
        )

        to_return = defaultdict(list)
        for i, row in df.iterrows():
            tmp_df = row["value"]
            tmp_df = tmp_df.assign(**row.drop("value"))
            to_return[row["type"]].append(tmp_df)
        for key in to_return.keys():
            df = pd.concat(to_return[key])
            df.name = "threshold_dependent_disaggregated_performance"
            to_return[key] = df

        disaggregated_thresh_results = []
        for key, df in to_return.items():
            labels = {"metric_type": key}
            disaggregated_thresh_results.append(
                TableContainer(df, **self.get_info(labels=labels))
            )

        return disaggregated_thresh_results
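
    # Parity values are computed as MetricFrame.difference(self.fairness_method) for
    # each scalar metric and reported alongside the explicit fairness metrics,
    # labeled "<metric>_parity".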
    def get_fairness_results(self):
        """Return fairness and performance parity metrics.

        Note: performance parity metrics are labeled with their
        related performance label, but are computed using
        fairlearn.metrics.MetricFrame.difference(method).

        Returns
        -------
        MetricContainer
            The returned fairness metrics
        """
        results = []
        for metric_name, metric in self.fairness_metrics.items():
            pred_argument = {"y_pred": self.y_pred}
            if metric.takes_prob:
                pred_argument = {"y_prob": self.y_prob}
            try:
                metric_value = metric.fun(
                    y_true=self.y_true,
                    sensitive_features=self.sensitive_features,
                    method=self.fairness_method,
                    **pred_argument,
                )
            except Exception as e:
                self.logger.error(
                    f"A metric ({metric_name}) failed to run. "
                    "Are you sure it works with this kind of model and target?\n"
                )
                raise e
            results.append({"metric_type": metric_name, "value": metric_value})

        results = pd.DataFrame.from_dict(results)

        # add parity results
        parity_results = []
        for name, metric_frame in self.metric_frames.items():
            if name == "thresh":
                # Don't calculate differences for curve metrics; this is not
                # mathematically well-defined.
                continue
            diffs = metric_frame.difference(self.fairness_method).rename(
                "{}_parity".format
            )
            diffs = pd.DataFrame({"metric_type": diffs.index, "value": diffs.values})
            parity_results.append(diffs)

        if parity_results:
            parity_results = pd.concat(parity_results)
            results = pd.concat([results, parity_results])

        results.rename({"metric_type": "type"}, axis=1, inplace=True)

        if results.empty:
            self.logger.info("No fairness metrics calculated.")
            return
        return MetricContainer(results, **self.get_info())