Coverage for credoai/evaluators/survival_fairness.py: 33%

67 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1from connect.evidence import TableContainer 

2 

3from credoai.artifacts import TabularData 

4from credoai.evaluators.evaluator import Evaluator 

5from credoai.evaluators.utils.validation import check_data_instance, check_existence 

6from credoai.modules import CoxPH 

7from credoai.modules.stats_utils import columns_from_formula 

8from credoai.utils import ValidationError 

9 

10 

class SurvivalFairness(Evaluator):
    """
    Calculate Survival fairness (Experimental)

    Fits Cox proportional-hazards models (``credoai.modules.CoxPH``) on a frame
    built from the assessment features, the model's predictions (and, when
    available, predicted probabilities), and the sensitive feature. Reports
    each fitted model's summary, expected survival, and survival curves as
    tables labelled with the sensitive feature's name.

    Parameters
    ----------
    CoxPh_kwargs : dict, optional
        Keyword arguments forwarded to ``CoxPH.fit``. If it contains a
        ``"formula"`` key, a single model is fit with exactly these kwargs.
        Otherwise, one model is fit per prediction column, using the generated
        formula ``"<sensitive feature> * <prediction column>"``.
        By default ``{"duration_col": "duration", "event_col": "event"}``.
    confounds : iterable of str, optional
        Column names of ``assessment_data.X``. They are appended as additive
        terms to the generated formula. They are not used when
        ``CoxPh_kwargs`` supplies its own formula, but they are still checked
        for existence in ``X``. By default None.
    """

    required_artifacts = ["model", "assessment_data", "sensitive_feature"]

    def __init__(self, CoxPh_kwargs=None, confounds=None):
        if CoxPh_kwargs is None:
            # A fresh dict per instance, so instances never share mutable state.
            CoxPh_kwargs = {"duration_col": "duration", "event_col": "event"}
        self.coxPh_kwargs = CoxPh_kwargs
        self.confounds = confounds
        # Fitted CoxPH objects; repopulated by _run_survival_analyses.
        self.stats = []

    def _validate_arguments(self):
        """Validate the assessment data and the columns the CoxPH setup references.

        Raises
        ------
        ValidationError
            If any confound, or any column referenced by a user-supplied
            ``formula``, is missing from ``assessment_data.X``.
        """
        check_data_instance(self.assessment_data, TabularData)
        check_existence(self.assessment_data.sensitive_features, "sensitive_features")
        # Check that the referenced columns exist. Start from a set (never None)
        # so a user-supplied formula can be validated even with no confounds.
        expected_columns = set(self.confounds or [])
        if "formula" in self.coxPh_kwargs:
            expected_columns |= columns_from_formula(self.coxPh_kwargs["formula"])
            # _setup adds these columns to the survival frame, so they are
            # not expected to be present in X.
            expected_columns -= {"predictions", "predicted_probabilities"}
            # NOTE(review): _setup joins the sensitive feature onto the frame
            # separately from X. A formula that references the sensitive feature
            # is reported missing unless X also contains it. Confirm intended.
        missing_columns = expected_columns.difference(self.assessment_data.X)
        if missing_columns:
            raise ValidationError(
                f"Columns supplied to CoxPh formula not found in data. Columns are: {missing_columns}"
            )

    def _setup(self):
        """Build ``self.survival_df``: X plus predictions and the sensitive feature.

        Also stores ``y_pred`` and ``sensitive_name``. ``y_prob`` holds the
        output of ``predict_proba``, or None if it cannot be obtained or
        cannot be added as a column.

        Returns
        -------
        self
        """
        self.y_pred = self.model.predict(self.assessment_data.X)
        self.sensitive_name = self.assessment_data.sensitive_feature.name
        self.survival_df = self.assessment_data.X.copy()
        self.survival_df["predictions"] = self.y_pred
        self.survival_df = self.survival_df.join(self.assessment_data.sensitive_feature)
        # Add probabilities on a best-effort basis. The model may not support
        # predict_proba, or its output may not be assignable as a single column.
        # In that case, fall back to analysing hard predictions only.
        try:
            self.y_prob = self.model.predict_proba(self.assessment_data.X)
            self.survival_df["predicted_probabilities"] = self.y_prob
        except Exception:
            self.y_prob = None
        return self

    def evaluate(self):
        """Fit the survival models and wrap their outputs in ``TableContainer``s.

        Results are ordered as all summaries, then all expected-survival tables,
        then all survival curves. Each result is labelled with the
        sensitive feature.

        Returns
        -------
        self
        """
        self._run_survival_analyses()
        result_dfs = (
            self._get_summaries()
            + self._get_expected_survival()
            + self._get_survival_curves()
        )
        sens_feat_label = {"sensitive_feature": self.sensitive_name}
        self.results = [
            TableContainer(df, **self.get_info(labels=sens_feat_label))
            for df in result_dfs
        ]
        return self

    def _run_survival_analyses(self):
        """Fit CoxPH models on ``self.survival_df`` and store them in ``self.stats``.

        If ``CoxPh_kwargs`` contains ``"formula"``, a single model is fit with
        those kwargs. Otherwise, one model is fit per prediction column:
        ``predictions``, plus ``predicted_probabilities`` when probabilities
        were obtained. Each uses the formula
        ``"<sensitive feature> * <column>"`` with any confounds added as
        ``" + <confound>"`` terms.
        """
        # Reset so that repeated evaluate() calls don't accumulate duplicate
        # fitted models (and therefore duplicate result tables).
        self.stats = []
        if "formula" in self.coxPh_kwargs:
            cph = CoxPH()
            cph.fit(self.survival_df, **self.coxPh_kwargs)
            self.stats.append(cph)
            return

        model_predictions = ["predictions"]
        if self.y_prob is not None:
            model_predictions.append("predicted_probabilities")
        # Produces e.g. " + age + income"; an empty string when there are no confounds.
        confound_terms = "".join(f" + {c}" for c in (self.confounds or []))
        for pred in model_predictions:
            run_kwargs = self.coxPh_kwargs.copy()
            run_kwargs["formula"] = f"{self.sensitive_name} * {pred}{confound_terms}"
            cph = CoxPH()
            cph.fit(self.survival_df, **run_kwargs)
            self.stats.append(cph)

    def _get_expected_survival(self):
        """Return ``expected_survival()`` for each fitted model, in fit order."""
        return [s.expected_survival() for s in self.stats]

    def _get_summaries(self):
        """Return ``summary()`` for each fitted model, in fit order."""
        return [s.summary() for s in self.stats]

    def _get_survival_curves(self):
        """Return ``survival_curves()`` for each fitted model, in fit order."""
        return [s.survival_curves() for s in self.stats]