Coverage for credoai/evaluators/survival_fairness.py: 25%
67 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
1from credoai.artifacts import TabularData
2from credoai.evaluators import Evaluator
3from credoai.evaluators.utils.validation import (
4 check_data_instance,
5 check_existence,
6)
7from connect.evidence import TableContainer
8from credoai.modules import CoxPH
9from credoai.modules.stats_utils import columns_from_formula
10from credoai.utils import ValidationError
class SurvivalFairness(Evaluator):
    """
    Calculate Survival fairness

    Fits Cox proportional-hazards models (via ``CoxPH``) on a frame built from
    the assessment features, the model's outputs and the sensitive feature.
    For each fitted model it reports the summary, expected survival and
    survival curves as ``TableContainer`` results labelled with the sensitive
    feature name.

    Parameters
    ----------
    CoxPh_kwargs : dict, optional
        Keyword arguments forwarded to ``CoxPH.fit``. If it contains a
        ``"formula"`` key, a single model is fit with exactly those kwargs.
        Otherwise one model is fit per model output ("predictions" and, when
        available, "predicted_probabilities") using the formula
        ``"<sensitive feature> * <output>"`` (plus any ``confounds``).
        By default ``{"duration_col": "duration", "event_col": "event"}``.
    confounds : list of str, optional
        Column names of the assessment data that are appended as additive
        terms to the generated formulas. They are not added to a
        user-supplied ``"formula"``, but their existence is still validated.
        By default None.
    """

    required_artifacts = ["model", "assessment_data", "sensitive_feature"]

    def __init__(self, CoxPh_kwargs=None, confounds=None):
        if CoxPh_kwargs is None:
            CoxPh_kwargs = {"duration_col": "duration", "event_col": "event"}
        self.coxPh_kwargs = CoxPh_kwargs
        self.confounds = confounds
        # Fitted CoxPH objects; (re)populated by _run_survival_analyses.
        self.stats = []

    def _validate_arguments(self):
        """Check the assessment data and that every referenced column exists.

        Raises
        ------
        ValidationError
            If a confound or a column named in a user-supplied formula is not
            present in the assessment data.
        """
        check_data_instance(self.assessment_data, TabularData)
        check_existence(self.assessment_data.sensitive_features, "sensitive_features")
        # Start from an empty set (not None) so a formula can be validated even
        # when no confounds were given.
        expected_columns = set(self.confounds or [])
        if "formula" in self.coxPh_kwargs:
            expected_columns |= set(columns_from_formula(self.coxPh_kwargs["formula"]))
            # The sensitive feature is joined onto the survival frame in
            # _setup rather than living in X, so a formula may legitimately
            # reference it.
            expected_columns.discard(self.assessment_data.sensitive_feature.name)
        # These columns are created by _setup, so they are never in X.
        expected_columns -= {"predictions", "predicted_probabilities"}
        missing_columns = expected_columns.difference(self.assessment_data.X)
        if missing_columns:
            raise ValidationError(
                f"Columns supplied to CoxPh formula not found in data. Columns are: {missing_columns}"
            )

    def _setup(self):
        """Build ``survival_df``: X + model outputs + the sensitive feature."""
        self.y_pred = self.model.predict(self.assessment_data.X)
        self.sensitive_name = self.assessment_data.sensitive_feature.name
        self.survival_df = self.assessment_data.X.copy()
        self.survival_df["predictions"] = self.y_pred
        self.survival_df = self.survival_df.join(self.assessment_data.sensitive_feature)
        # Probabilities are best-effort: if the model has no predict_proba, or
        # its output cannot be stored as the single "predicted_probabilities"
        # column, analyses fall back to hard predictions only. Catch Exception
        # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        try:
            self.y_prob = self.model.predict_proba(self.assessment_data.X)
            self.survival_df["predicted_probabilities"] = self.y_prob
        except Exception:
            self.y_prob = None
        return self

    def evaluate(self):
        """Fit the survival models and wrap every result table in a TableContainer."""
        self._run_survival_analyses()
        result_dfs = (
            self._get_summaries()
            + self._get_expected_survival()
            + self._get_survival_curves()
        )
        sens_feat_label = {"sensitive_feature": self.sensitive_name}
        self.results = [
            TableContainer(df, **self.get_container_info(labels=sens_feat_label))
            for df in result_dfs
        ]
        return self

    def _run_survival_analyses(self):
        """Fit one CoxPH model per formula and store the fitted models in ``self.stats``."""
        # Reset so repeated evaluate() calls don't accumulate duplicate fits
        # (and therefore duplicate result tables).
        self.stats = []
        if "formula" in self.coxPh_kwargs:
            cph = CoxPH()
            cph.fit(self.survival_df, **self.coxPh_kwargs)
            self.stats.append(cph)
            return

        model_predictions = ["predictions"]
        if self.y_prob is not None:
            model_predictions.append("predicted_probabilities")
        # Confounds enter as additive terms: " + c1 + c2 ..."
        confound_terms = "".join(f" + {col}" for col in self.confounds or [])
        for pred in model_predictions:
            run_kwargs = self.coxPh_kwargs.copy()
            # The interaction term captures whether the prediction's
            # relationship with survival differs across sensitive groups.
            run_kwargs["formula"] = f"{self.sensitive_name} * {pred}{confound_terms}"
            cph = CoxPH()
            cph.fit(self.survival_df, **run_kwargs)
            self.stats.append(cph)

    def _get_expected_survival(self):
        """Return the expected-survival table of every fitted model."""
        return [s.expected_survival() for s in self.stats]

    def _get_summaries(self):
        """Return the summary table of every fitted model."""
        return [s.summary() for s in self.stats]

    def _get_survival_curves(self):
        """Return the survival-curve table of every fitted model."""
        return [s.survival_curves() for s in self.stats]