Coverage for credoai/evaluators/survival_fairness.py: 33%
67 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1from connect.evidence import TableContainer
3from credoai.artifacts import TabularData
4from credoai.evaluators.evaluator import Evaluator
5from credoai.evaluators.utils.validation import check_data_instance, check_existence
6from credoai.modules import CoxPH
7from credoai.modules.stats_utils import columns_from_formula
8from credoai.utils import ValidationError
class SurvivalFairness(Evaluator):
    """
    Calculate Survival fairness (Experimental)

    Fits one or more ``CoxPH`` models on the assessment data joined with the
    model's predictions and the sensitive feature, then collects each fitted
    model's summary, expected survival and survival curves as result tables.

    Parameters
    ----------
    CoxPh_kwargs : dict, optional
        Keyword arguments forwarded to ``CoxPH.fit``. By default None, in
        which case ``{"duration_col": "duration", "event_col": "event"}`` is
        used. If the dictionary contains a ``"formula"`` key, a single model
        is fit with exactly these arguments. Otherwise one model is fit per
        available prediction column with the formula
        ``"<sensitive feature> * <prediction column>"``.
    confounds : list of str, optional
        Column names appended as additive terms to the automatically
        generated formulas, by default None. They are not added to a
        user-supplied ``"formula"``, but they are still validated against
        the assessment data.
    """

    # Artifacts this evaluator reads from ``self`` (model, assessment_data).
    required_artifacts = ["model", "assessment_data", "sensitive_feature"]

    def __init__(self, CoxPh_kwargs=None, confounds=None):
        if CoxPh_kwargs is None:
            CoxPh_kwargs = {"duration_col": "duration", "event_col": "event"}
        self.coxPh_kwargs = CoxPh_kwargs
        self.confounds = confounds
        # Fitted CoxPH objects, filled by ``_run_survival_analyses``.
        self.stats = []

    def _validate_arguments(self):
        """
        Check the assessment data and the columns requested by the user.

        Raises
        ------
        ValidationError
            If a confound or a column referenced by the user-supplied
            formula is not a column of ``assessment_data.X``.
        """
        check_data_instance(self.assessment_data, TabularData)
        check_existence(self.assessment_data.sensitive_features, "sensitive_features")
        # check for columns existences
        # Start from a real (possibly empty) set: the previous ``None``
        # sentinel made ``|=`` raise TypeError for a formula without confounds.
        expected_columns = set(self.confounds) if self.confounds else set()
        if "formula" in self.coxPh_kwargs:
            expected_columns.update(columns_from_formula(self.coxPh_kwargs["formula"]))
            # These columns are not part of X; ``_setup`` adds them to the
            # frame the model is fit on, so a formula may legitimately
            # reference them.
            expected_columns -= {"predictions", "predicted_probabilities"}
            expected_columns.discard(self.assessment_data.sensitive_feature.name)
        missing_columns = expected_columns.difference(self.assessment_data.X)
        if missing_columns:
            raise ValidationError(
                f"Columns supplied to CoxPh formula not found in data. Columns are: {missing_columns}"
            )

    def _setup(self):
        """
        Build ``self.survival_df``, the frame the survival models are fit on.

        The frame is a copy of ``assessment_data.X`` with a ``"predictions"``
        column, the sensitive feature joined on the index and, when the
        model can provide them, a ``"predicted_probabilities"`` column.

        Returns
        -------
        self
        """
        self.y_pred = self.model.predict(self.assessment_data.X)
        self.sensitive_name = self.assessment_data.sensitive_feature.name
        self.survival_df = self.assessment_data.X.copy()
        self.survival_df["predictions"] = self.y_pred
        self.survival_df = self.survival_df.join(self.assessment_data.sensitive_feature)
        # add probabilities
        # Best effort: any failure to obtain or store probabilities means
        # they are skipped. ``Exception`` (not a bare ``except``) keeps
        # KeyboardInterrupt / SystemExit propagating.
        try:
            self.y_prob = self.model.predict_proba(self.assessment_data.X)
            self.survival_df["predicted_probabilities"] = self.y_prob
        except Exception:
            self.y_prob = None
        return self

    def evaluate(self):
        """
        Run the survival analyses and store their tables in ``self.results``.

        Each table is wrapped in a ``TableContainer`` labelled with the name
        of the sensitive feature.

        Returns
        -------
        self
        """
        self._run_survival_analyses()
        result_dfs = (
            self._get_summaries()
            + self._get_expected_survival()
            + self._get_survival_curves()
        )
        sens_feat_label = {"sensitive_feature": self.sensitive_name}
        self.results = [
            TableContainer(df, **self.get_info(labels=sens_feat_label))
            for df in result_dfs
        ]
        return self

    def _run_survival_analyses(self):
        """
        Fit the CoxPH model(s) and store them in ``self.stats``.

        With a user-supplied ``"formula"`` a single model is fit. Otherwise
        one model is fit for ``"predictions"`` and, if probabilities are
        available, one for ``"predicted_probabilities"``.
        """
        # Start fresh so repeated ``evaluate`` calls do not accumulate
        # duplicate models (and therefore duplicate result tables).
        self.stats = []

        if "formula" in self.coxPh_kwargs:
            cph = CoxPH()
            cph.fit(self.survival_df, **self.coxPh_kwargs)
            self.stats.append(cph)
            return

        model_predictions = (
            ["predictions", "predicted_probabilities"]
            if self.y_prob is not None
            else ["predictions"]
        )
        extra_terms = list(self.confounds) if self.confounds else []
        for pred in model_predictions:
            # Copy so the per-run formula never leaks into the user's kwargs.
            run_kwargs = self.coxPh_kwargs.copy()
            # Interaction of sensitive feature and prediction, plus one
            # additive term per confound.
            run_kwargs["formula"] = " + ".join(
                [f"{self.sensitive_name} * {pred}", *extra_terms]
            )
            cph = CoxPH()
            cph.fit(self.survival_df, **run_kwargs)
            self.stats.append(cph)

    def _get_expected_survival(self):
        """Return ``expected_survival()`` of every fitted model, in fit order."""
        return [s.expected_survival() for s in self.stats]

    def _get_summaries(self):
        """Return ``summary()`` of every fitted model, in fit order."""
        return [s.summary() for s in self.stats]

    def _get_survival_curves(self):
        """Return ``survival_curves()`` of every fitted model, in fit order."""
        return [s.survival_curves() for s in self.stats]