Coverage for credoai/modules/stats.py: 28%

39 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1from itertools import product 

2 

3import pandas as pd 

4from lifelines import CoxPHFitter 

5 

6from credoai.modules.stats_utils import columns_from_formula 

7 

8 

class CoxPH:
    """Wrapper around lifelines' ``CoxPHFitter`` that returns name-tagged result tables.

    Every table produced by this class gets a ``name`` attribute built from
    ``self.name``. After a fit with a ``formula`` keyword, ``self.name``
    includes that formula.
    """

    # Name without any formula suffix. It is kept separately so that refitting
    # rebuilds ``self.name`` instead of stacking "(formula: ...)" suffixes.
    _BASE_NAME = "Cox Proportional Hazard"

    def __init__(self, **kwargs):
        """Create an unfitted wrapper.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``CoxPHFitter``.
        """
        self.name = self._BASE_NAME
        self.cph = CoxPHFitter(**kwargs)
        self.fit_kwargs = {}
        self.data = None

    def fit(self, data, **fit_kwargs):
        """Fit the model and remember the inputs for later prediction grids.

        Parameters
        ----------
        data
            Training data passed straight to ``CoxPHFitter.fit``.
        **fit_kwargs
            Extra keyword arguments for ``CoxPHFitter.fit``. If ``formula`` is
            given, it is appended to ``self.name``.

        Returns
        -------
        CoxPH
            ``self``, so calls can be chained.
        """
        self.cph.fit(data, **fit_kwargs)
        # Only record the inputs once the underlying fit has succeeded.
        self.fit_kwargs = fit_kwargs
        self.data = data
        # Rebuild from the base name so repeated fits don't accumulate suffixes.
        self.name = self._BASE_NAME
        if "formula" in fit_kwargs:
            self.name += f" (formula: {fit_kwargs['formula']})"
        return self

    def summary(self):
        """Return ``CoxPHFitter.summary`` tagged with a ``name`` attribute.

        The name has the form ``"<self.name> Stat Summary"``.
        """
        s = self.cph.summary
        s.name = f"{self.name} Stat Summary"
        return s

    def expected_survival(self):
        """Predict the expected survival time for each covariate combination.

        Returns
        -------
        pandas.DataFrame
            The grid from ``_get_prediction_data`` plus an ``"E(time survive)"``
            column holding ``CoxPHFitter.predict_expectation`` for each row.
            The frame is tagged with a ``name`` attribute.
        """
        prediction_data = self._get_prediction_data()
        expected_predictions = self.cph.predict_expectation(prediction_data)
        expected_predictions.name = "E(time survive)"
        final = pd.concat([prediction_data, expected_predictions], axis=1)
        final.name = f"{self.name} Expected Survival"
        return final

    def survival_curves(self, step=5):
        """Return predicted survival curves in long format, one row per time point.

        Parameters
        ----------
        step : int, default 5
            Keep only rows whose ``time_step`` is a multiple of ``step``.
            The default reproduces the historical behaviour.

        Returns
        -------
        pandas.DataFrame
            Columns ``time_step``, ``value`` (the output of
            ``CoxPHFitter.predict_survival_function``) and the covariate
            columns of the prediction grid. The frame is tagged with a
            ``name`` attribute.
        """
        prediction_data = self._get_prediction_data()
        curves = self.cph.predict_survival_function(prediction_data)
        curves = (
            curves.loc[0:]  # rows whose timeline label is >= 0
            .rename_axis("time_step")
            .reset_index()
            # Wide (one column per grid row) -> long; "variable" holds the
            # grid row label, which is matched back to the grid's index below.
            .melt(id_vars=["time_step"])
            .merge(right=prediction_data, left_on="variable", right_index=True)
            .drop(columns=["variable"])
        )
        # Thin the curves to every `step`-th time point (presumably to keep
        # the output table compact).
        curves = curves[curves["time_step"] % step == 0]
        curves.name = f"{self.name} Survival Curves"
        return curves

    def _get_prediction_data(self):
        """Build the Cartesian product of the observed values of each model column.

        The columns come from ``columns_from_formula`` applied to the
        ``formula`` used in ``fit`` (``None`` if no formula was given).
        ``fit`` must have been called first, because ``self.data`` is read.
        """
        columns = columns_from_formula(self.fit_kwargs.get("formula"))
        # DataFrame.iteritems was removed in pandas 2.0; .items() is the
        # equivalent that works on both old and new pandas.
        unique_values = [series.unique() for _, series in self.data[columns].items()]
        return pd.DataFrame(list(product(*unique_values)), columns=columns)