Coverage for credoai/modules/stats.py: 28%
39 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
1from itertools import product
3import pandas as pd
4from lifelines import CoxPHFitter
6from credoai.modules.stats_utils import columns_from_formula
class CoxPH:
    """Thin wrapper around lifelines' ``CoxPHFitter`` producing named result tables.

    Each public result method returns a pandas object carrying an ad-hoc
    ``.name`` attribute, a human-readable title derived from ``self.name``.
    Note that pandas does not propagate such ad-hoc attributes through most
    operations, so the title is only reliable on the returned object itself.
    """

    # Base display name. ``fit`` rebuilds ``self.name`` from this so that
    # repeated fits do not keep appending "(formula: ...)" suffixes.
    _BASE_NAME = "Cox Proportional Hazard"

    def __init__(self, **kwargs):
        """Create an unfitted model.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``lifelines.CoxPHFitter``.
        """
        self.name = self._BASE_NAME
        self.cph = CoxPHFitter(**kwargs)
        # Keyword arguments of the most recent successful ``fit`` call.
        self.fit_kwargs = {}
        # Training data of the most recent successful ``fit``; None until fitted.
        self.data = None

    def fit(self, data, **fit_kwargs):
        """Fit the underlying ``CoxPHFitter`` and remember the fit inputs.

        Parameters
        ----------
        data : pandas.DataFrame
            Training data, passed straight to ``CoxPHFitter.fit``.
        **fit_kwargs
            Forwarded to ``CoxPHFitter.fit``. If ``formula`` is present it is
            shown in ``self.name``, and ``_get_prediction_data`` later uses it
            to pick the columns of the prediction grid.

        Returns
        -------
        CoxPH
            ``self``, to allow chaining.
        """
        self.cph.fit(data, **fit_kwargs)
        # Record state only after the underlying fit has succeeded.
        self.fit_kwargs = fit_kwargs
        self.data = data
        # Rebuild from the base name; a bare ``+=`` would accumulate one
        # formula suffix per call when the model is refitted.
        self.name = self._BASE_NAME
        if "formula" in fit_kwargs:
            self.name += f" (formula: {fit_kwargs['formula']})"
        return self

    def summary(self):
        """Return the fitted model's summary table.

        Returns
        -------
        pandas.DataFrame
            ``self.cph.summary`` with ``.name`` set to "<name> Stat Summary".
        """
        s = self.cph.summary
        s.name = f"{self.name} Stat Summary"
        return s

    def expected_survival(self):
        """Return the expected survival time for every covariate combination.

        Returns
        -------
        pandas.DataFrame
            The grid from ``_get_prediction_data`` with the output of
            ``CoxPHFitter.predict_expectation`` concatenated column-wise
            (aligned on the grid's index).
        """
        prediction_data = self._get_prediction_data()
        expected_predictions = self.cph.predict_expectation(prediction_data)
        # NOTE(review): assigning ``.name`` only labels the new column if
        # predict_expectation returns a Series -- confirm for the lifelines
        # version in use.
        expected_predictions.name = "E(time survive)"
        final = pd.concat([prediction_data, expected_predictions], axis=1)
        final.name = f"{self.name} Expected Survival"
        return final

    def survival_curves(self, time_step_interval=5):
        """Return predicted survival curves in long format.

        Parameters
        ----------
        time_step_interval : int, default 5
            Keep only rows whose ``time_step`` is a multiple of this value.
            The default of 5 matches the previous fixed down-sampling.

        Returns
        -------
        pandas.DataFrame
            Columns ``time_step`` and ``value`` (the predicted survival
            function value), followed by the covariate columns of the
            prediction grid.

        Raises
        ------
        ValueError
            If ``time_step_interval`` is not positive.
        """
        if time_step_interval <= 0:
            raise ValueError(
                f"time_step_interval must be positive, got {time_step_interval!r}"
            )
        prediction_data = self._get_prediction_data()
        survival_curves = self.cph.predict_survival_function(prediction_data)
        survival_curves = (
            # Keep index labels >= 0 (presumably the timeline -- confirm).
            survival_curves.loc[0:]
            .rename_axis("time_step")
            .reset_index()
            # Wide -> long: "variable" holds the original column label, which
            # is expected to match the prediction grid's index.
            .melt(id_vars=["time_step"])
            # Attach covariate values by joining that label to the grid index.
            .merge(right=prediction_data, left_on="variable", right_index=True)
            .drop(columns=["variable"])
        )
        survival_curves = survival_curves[
            survival_curves["time_step"] % time_step_interval == 0
        ]
        survival_curves.name = f"{self.name} Survival Curves"
        return survival_curves

    def _get_prediction_data(self):
        """Build the Cartesian product of observed values of the model columns.

        The columns come from ``columns_from_formula`` applied to the
        ``formula`` passed to ``fit`` (``None`` if no formula was given).
        Requires ``fit`` to have been called so that ``self.data`` is set.

        Returns
        -------
        pandas.DataFrame
            One row per combination of the unique values of each column.
        """
        columns = columns_from_formula(self.fit_kwargs.get("formula"))
        # ``DataFrame.iteritems`` was deprecated in pandas 1.5 and removed in
        # 2.0; ``items`` yields the same (label, Series) pairs.
        unique_values = [col.unique() for _, col in self.data[columns].items()]
        return pd.DataFrame(list(product(*unique_values)), columns=columns)