Coverage for credoai/evaluators/utils/fairlearn.py: 94%
18 statements
from fairlearn.metrics import MetricFrame

from credoai.modules.constants_metrics import THRESHOLD_METRIC_CATEGORIES
from credoai.utils import global_logger, wrap_list

########### General functions shared across evaluators ###########
def create_metric_frame(metrics, y_pred, y_true, sensitive_features):
    """Creates a MetricFrame from a dictionary of name: Metric pairs"""
    metrics = {name: metric.fun for name, metric in metrics.items()}
    return MetricFrame(
        metrics=metrics,
        y_true=y_true,
        y_pred=y_pred,
        sensitive_features=sensitive_features,
    )
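
# Illustrative usage sketch, not part of the module: create_metric_frame
# expects Metric-like objects exposing a callable `fun`. The namedtuple and
# the metric name below are hypothetical stand-ins for credoai's Metric class.
def _example_create_metric_frame():
    from collections import namedtuple

    from sklearn.metrics import accuracy_score

    DemoMetric = namedtuple("DemoMetric", ["fun"])
    frame = create_metric_frame(
        metrics={"accuracy": DemoMetric(fun=accuracy_score)},
        y_pred=[1, 0, 1, 1],
        y_true=[1, 0, 0, 1],
        sensitive_features=["a", "a", "b", "b"],
    )
    # MetricFrame exposes overall and disaggregated (per-group) results
    return frame.overall, frame.by_group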
def filter_processed_metrics(
    processed_metrics, metric_categories=None, Xmetric_categories=None, takes_prob=None
):
    """
    Filters processed metrics.

    Any argument that is None is ignored when filtering.

    Parameters
    ----------
    processed_metrics: dict
        Dictionary of metrics (dict of str: Metric)
    metric_categories: str or list
        Positive metric categories to filter metrics. Each metric must have a
        metric_category within this list. The list of metric categories is stored
        in modules.constants_metrics.METRIC_CATEGORIES.
    Xmetric_categories: str or list
        Negative metric categories to filter metrics. Each metric must have a
        metric_category NOT within this list. The list of metric categories is
        stored in modules.constants_metrics.METRIC_CATEGORIES.
    takes_prob: bool
        Whether the metric takes probabilities
    """
    metric_categories = wrap_list(metric_categories)
    # Wrap Xmetric_categories too, so a bare string behaves as a one-element
    # list instead of triggering substring membership checks.
    Xmetric_categories = wrap_list(Xmetric_categories)
    return {
        name: metric
        for name, metric in processed_metrics.items()
        if (metric_categories is None or metric.metric_category in metric_categories)
        and (
            Xmetric_categories is None
            or metric.metric_category not in Xmetric_categories
        )
        and (takes_prob is None or metric.takes_prob == takes_prob)
    }
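
# Illustrative sketch, not part of the module: filtering a hypothetical metric
# dictionary. SimpleNamespace stands in for credoai's Metric objects, which
# carry the `metric_category` and `takes_prob` attributes used above; the
# category strings are assumed values, not necessarily real credoai constants.
def _example_filter_processed_metrics():
    from types import SimpleNamespace

    processed = {
        "accuracy": SimpleNamespace(
            metric_category="BINARY_CLASSIFICATION", takes_prob=False
        ),
        "roc_auc": SimpleNamespace(
            metric_category="BINARY_CLASSIFICATION", takes_prob=True
        ),
        "det_curve": SimpleNamespace(
            metric_category="BINARY_CLASSIFICATION_THRESHOLD", takes_prob=True
        ),
    }
    # Keep probability-based metrics that fall outside the threshold category
    return filter_processed_metrics(
        processed,
        Xmetric_categories=["BINARY_CLASSIFICATION_THRESHOLD"],
        takes_prob=True,
    )  # -> {"roc_auc": ...}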
def setup_metric_frames(
    processed_metrics,
    y_pred,
    y_prob,
    y_true,
    sensitive_features,
):
    metric_frames = {}

    # tuple structure: (metric frame name, y_input, dictionary of metrics)
    metric_frame_tuples = [
        ("pred", y_pred, filter_processed_metrics(processed_metrics, takes_prob=False)),
        (
            "prob",
            y_prob,
            filter_processed_metrics(
                processed_metrics,
                Xmetric_categories=THRESHOLD_METRIC_CATEGORIES,
                takes_prob=True,
            ),
        ),
        (
            "thresh",
            y_prob,
            filter_processed_metrics(
                processed_metrics,
                metric_categories=THRESHOLD_METRIC_CATEGORIES,
                takes_prob=True,
            ),
        ),
    ]
    for name, y, metrics in metric_frame_tuples:
        if metrics:
            if y is not None:
                metric_frames[name] = create_metric_frame(
                    metrics, y, y_true, sensitive_features
                )
            else:
                global_logger.warning(
                    f"Metrics ({list(metrics.keys())}) requested for {name} metric frame, but no appropriate y available"
                )

    return metric_frames
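
# Illustrative sketch, not part of the module: wiring the helpers together.
# Metric objects are again mocked with SimpleNamespace; real callers pass
# credoai Metric instances keyed by name. The category string is assumed,
# and "roc_auc" lands in the "prob" frame only on the assumption that
# "BINARY_CLASSIFICATION" is not among THRESHOLD_METRIC_CATEGORIES.
def _example_setup_metric_frames():
    from types import SimpleNamespace

    from sklearn.metrics import accuracy_score, roc_auc_score

    processed_metrics = {
        "accuracy": SimpleNamespace(
            fun=accuracy_score,
            metric_category="BINARY_CLASSIFICATION",
            takes_prob=False,
        ),
        "roc_auc": SimpleNamespace(
            fun=roc_auc_score,
            metric_category="BINARY_CLASSIFICATION",
            takes_prob=True,
        ),
    }
    # "pred" gets accuracy (takes_prob=False); "prob" gets roc_auc; "thresh"
    # stays empty here because no threshold-category metrics were supplied.
    return setup_metric_frames(
        processed_metrics,
        y_pred=[1, 0, 1, 1],
        y_prob=[0.9, 0.2, 0.7, 0.6],
        y_true=[1, 0, 0, 1],
        sensitive_features=["a", "a", "b", "b"],
    )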