Coverage for credoai/evaluators/utils/fairlearn.py: 94%

18 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1from fairlearn.metrics import MetricFrame 

2 

3from credoai.modules.constants_metrics import THRESHOLD_METRIC_CATEGORIES 

4from credoai.utils import global_logger, wrap_list 

5 

6########### General functions shared across evaluators ########### 

7 

8 

def create_metric_frame(metrics, y_pred, y_true, sensitive_features):
    """Creates metric frame from dictionary of key:Metric

    Parameters
    ----------
    metrics : dict
        Mapping of metric name to a Metric object; only each object's
        ``fun`` attribute is used.
    y_pred : array-like
        Model outputs passed to fairlearn as ``y_pred``.
    y_true : array-like
        Ground-truth labels passed to fairlearn as ``y_true``.
    sensitive_features : array-like
        Sensitive feature values used by fairlearn to group the metrics.

    Returns
    -------
    fairlearn.metrics.MetricFrame
        Frame evaluating every metric function overall and per group.
    """
    # fairlearn expects plain callables, so unwrap each Metric into its function
    metric_functions = {}
    for metric_name, metric_obj in metrics.items():
        metric_functions[metric_name] = metric_obj.fun

    frame = MetricFrame(
        metrics=metric_functions,
        y_true=y_true,
        y_pred=y_pred,
        sensitive_features=sensitive_features,
    )
    return frame

18 

19 

def filter_processed_metrics(
    processed_metrics, metric_categories=None, Xmetric_categories=None, takes_prob=None
):
    """
    Filters processed metrics

    If any filtering argument is None, it will be ignored for filtering

    Parameters
    ----------
    processed_metrics: dict
        Dictionary of metrics (dict of str: Metric)
    metric_categories: str or list
        Positive metric categories to filter metrics. Each metric must have a metric_category
        within this list. The list of metric categories is stored in modules.constants_metrics.METRIC_CATEGORIES
    Xmetric_categories: str or list
        Negative metric categories to filter metrics. Each metric must have a metric_category
        NOT within this list. The list of metric categories is stored in modules.constants_metrics.METRIC_CATEGORIES
    takes_prob: bool
        Whether the metric takes probabilities

    Returns
    -------
    dict
        The subset of ``processed_metrics`` (same names and Metric objects)
        whose metrics satisfy every filter that is not None.
    """
    # Normalize BOTH category filters. Without this, passing a single string
    # as Xmetric_categories would make `not in` perform a substring test on
    # that string instead of a membership test on a list of categories.
    if metric_categories is not None:
        metric_categories = wrap_list(metric_categories)
    if Xmetric_categories is not None:
        Xmetric_categories = wrap_list(Xmetric_categories)

    return {
        name: metric
        for name, metric in processed_metrics.items()
        if (metric_categories is None or metric.metric_category in metric_categories)
        and (
            Xmetric_categories is None
            or metric.metric_category not in Xmetric_categories
        )
        and (takes_prob is None or metric.takes_prob == takes_prob)
    }

52 

53 

def setup_metric_frames(
    processed_metrics,
    y_pred,
    y_prob,
    y_true,
    sensitive_features,
):
    """Build fairlearn MetricFrames for each kind of model output.

    The metrics are split into three groups:

    - "pred": metrics that do not take probabilities, evaluated on ``y_pred``
    - "prob": probability metrics outside THRESHOLD_METRIC_CATEGORIES,
      evaluated on ``y_prob``
    - "thresh": probability metrics inside THRESHOLD_METRIC_CATEGORIES,
      evaluated on ``y_prob``

    A frame is created only when its group contains at least one metric and
    its y input is not None. A group with metrics but no y input is skipped,
    and a warning is logged.

    Parameters
    ----------
    processed_metrics : dict
        Dictionary of metrics (dict of str: Metric).
    y_pred : array-like or None
        Model predictions used for the "pred" frame.
    y_prob : array-like or None
        Model probabilities used for the "prob" and "thresh" frames.
    y_true : array-like
        Ground-truth labels.
    sensitive_features : array-like
        Sensitive feature values used to group the metrics.

    Returns
    -------
    dict
        Mapping of frame name ("pred", "prob", "thresh") to MetricFrame,
        containing only the frames that could be built.
    """
    metric_frames = {}

    # tuple structure: (metric frame name, y_input, dictionary of metrics)
    metric_frame_tuples = [
        ("pred", y_pred, filter_processed_metrics(processed_metrics, takes_prob=False)),
        (
            "prob",
            y_prob,
            filter_processed_metrics(
                processed_metrics,
                Xmetric_categories=THRESHOLD_METRIC_CATEGORIES,
                takes_prob=True,
            ),
        ),
        (
            "thresh",
            y_prob,
            filter_processed_metrics(
                processed_metrics,
                metric_categories=THRESHOLD_METRIC_CATEGORIES,
                takes_prob=True,
            ),
        ),
    ]

    for name, y, metrics in metric_frame_tuples:
        if not metrics:
            # Nothing requested for this frame type; skip silently
            continue
        if y is None:
            # Logger.warn is a deprecated alias; use warning()
            global_logger.warning(
                f"Metrics ({list(metrics.keys())}) requested for {name} metric frame, but no appropriate y available"
            )
            continue
        metric_frames[name] = create_metric_frame(
            metrics, y, y_true, sensitive_features
        )

    return metric_frames