Coverage for credoai/evaluators/survival_fairness.py: 25%

67 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1from credoai.artifacts import TabularData 

2from credoai.evaluators import Evaluator 

3from credoai.evaluators.utils.validation import ( 

4 check_data_instance, 

5 check_existence, 

6) 

7from connect.evidence import TableContainer 

8from credoai.modules import CoxPH 

9from credoai.modules.stats_utils import columns_from_formula 

10from credoai.utils import ValidationError 

11 

12 

class SurvivalFairness(Evaluator):
    """
    Calculate survival fairness using Cox proportional-hazards models.

    For each available model output a CoxPH model is fit on the assessment
    data. The outputs are ``predictions`` and, when the model's
    ``predict_proba`` can be used, ``predicted_probabilities``. Each fit uses
    the formula ``"<sensitive feature> * <model output>"`` followed by any
    ``confounds`` as additive terms.

    If ``CoxPh_kwargs`` already contains a ``formula``, a single model is fit
    with exactly that formula. In that case ``confounds`` are only used for
    column validation.

    Parameters
    ----------
    CoxPh_kwargs : dict, optional
        Keyword arguments forwarded to ``CoxPH.fit``. By default
        ``{"duration_col": "duration", "event_col": "event"}``.
    confounds : list of str or str, optional
        Column name(s) of ``assessment_data.X`` appended as additive terms to
        each generated formula. By default None (no confounds).
    """

    required_artifacts = ["model", "assessment_data", "sensitive_feature"]

    def __init__(self, CoxPh_kwargs=None, confounds=None):
        if CoxPh_kwargs is None:
            CoxPh_kwargs = {"duration_col": "duration", "event_col": "event"}
        # A single column name would otherwise be iterated character by character.
        if isinstance(confounds, str):
            confounds = [confounds]
        self.coxPh_kwargs = CoxPh_kwargs
        self.confounds = confounds
        # Fitted CoxPH objects, populated by _run_survival_analyses.
        self.stats = []

    def _validate_arguments(self):
        """
        Check the assessment data and the columns the survival formulas need.

        Raises
        ------
        ValidationError
            If a confound or a column referenced by a user-supplied
            ``formula`` is not present in the data.
        """
        check_data_instance(self.assessment_data, TabularData)
        check_existence(self.assessment_data.sensitive_features, "sensitive_features")

        # Collect every column referenced by confounds and/or the user formula.
        expected_columns = set(self.confounds or [])
        if "formula" in self.coxPh_kwargs:
            expected_columns.update(columns_from_formula(self.coxPh_kwargs["formula"]))
        # These columns are created in _setup, not read from X.
        expected_columns -= {"predictions", "predicted_probabilities"}
        if not expected_columns:
            return

        # _setup joins the sensitive feature onto X, so it is a valid column too.
        available_columns = set(self.assessment_data.X)
        sensitive_feature = getattr(self.assessment_data, "sensitive_feature", None)
        if sensitive_feature is not None:
            available_columns.add(sensitive_feature.name)

        missing_columns = expected_columns - available_columns
        if missing_columns:
            raise ValidationError(
                f"Columns supplied to CoxPh formula not found in data. Columns are: {missing_columns}"
            )

    def _setup(self):
        """Build ``survival_df``: X plus model outputs and the sensitive feature."""
        X = self.assessment_data.X
        self.y_pred = self.model.predict(X)
        self.sensitive_name = self.assessment_data.sensitive_feature.name
        self.survival_df = X.copy()
        self.survival_df["predictions"] = self.y_pred
        self.survival_df = self.survival_df.join(self.assessment_data.sensitive_feature)
        # Probabilities are optional and best-effort. If predict_proba is
        # unavailable or its output cannot be stored as a column, only hard
        # predictions are analysed. Exception (not a bare except) lets
        # KeyboardInterrupt/SystemExit propagate.
        try:
            self.y_prob = self.model.predict_proba(X)
            self.survival_df["predicted_probabilities"] = self.y_prob
        except Exception:
            self.y_prob = None
        return self

    def evaluate(self):
        """Fit the survival models and wrap each result table in a TableContainer."""
        self._run_survival_analyses()
        result_dfs = (
            self._get_summaries()
            + self._get_expected_survival()
            + self._get_survival_curves()
        )
        sens_feat_label = {"sensitive_feature": self.sensitive_name}
        self.results = [
            TableContainer(df, **self.get_container_info(labels=sens_feat_label))
            for df in result_dfs
        ]
        return self

    def _fit_cox(self, fit_kwargs):
        """Fit one CoxPH model on ``survival_df`` with ``fit_kwargs`` and return it."""
        cph = CoxPH()
        cph.fit(self.survival_df, **fit_kwargs)
        return cph

    def _run_survival_analyses(self):
        """Populate ``self.stats`` with fitted CoxPH models."""
        # Reset so repeated evaluate() calls do not accumulate duplicate models.
        self.stats = []
        if "formula" in self.coxPh_kwargs:
            self.stats.append(self._fit_cox(self.coxPh_kwargs))
            return

        model_predictions = ["predictions"]
        if self.y_prob is not None:
            model_predictions.append("predicted_probabilities")
        for pred in model_predictions:
            # e.g. "<sensitive> * predictions + conf1 + conf2"
            terms = [f"{self.sensitive_name} * {pred}", *(self.confounds or [])]
            run_kwargs = {**self.coxPh_kwargs, "formula": " + ".join(terms)}
            self.stats.append(self._fit_cox(run_kwargs))

    def _get_expected_survival(self):
        """Return ``expected_survival()`` from each fitted model."""
        return [s.expected_survival() for s in self.stats]

    def _get_summaries(self):
        """Return ``summary()`` from each fitted model."""
        return [s.summary() for s in self.stats]

    def _get_survival_curves(self):
        """Return ``survival_curves()`` from each fitted model."""
        return [s.survival_curves() for s in self.stats]