Coverage for credoai/artifacts/model/classification_model.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1"""Model artifact wrapping any classification model""" 

2from .base_model import Model 

3 

# Framework identifiers for which ClassificationModel._update_functionality
# may narrow a binary model's ``predict_proba`` output to a single column.
PREDICT_PROBA_FRAMEWORKS = [
    "sklearn",
    "xgboost",
]

5 

6 

class ClassificationModel(Model):
    """Class wrapper around classification model to be assessed

    ClassificationModel serves as an adapter between arbitrary binary or multi-class
    classification models and the evaluations in Lens. Evaluations depend on
    ClassificationModel instantiating `predict` and (optionally) `predict_proba`.

    Parameters
    ----------
    name : str
        Label of the model
    model_like : model_like
        A binary or multi-class classification model or pipeline. It must have a
        `predict` function that returns array containing the class labels for each sample.
        It can also optionally have a `predict_proba` function that returns array containing
        the class labels probabilities for each sample.
    tags : optional
        Passed through unchanged to the base ``Model`` constructor.
    """

    def __init__(self, name: str, model_like=None, tags=None):
        # Positional arguments to the base class: model type label, the
        # functions this wrapper may expose, and the subset that is required.
        # NOTE(review): argument meanings inferred from values — confirm
        # against base_model.Model.
        super().__init__(
            "Classification",
            ["predict", "predict_proba"],
            ["predict"],
            name,
            model_like,
            tags,
        )

    def _update_functionality(self):
        """Conditionally updates functionality based on framework.

        When the framework recorded in ``self.model_info["framework"]`` is one of
        ``PREDICT_PROBA_FRAMEWORKS`` and the wrapped model reports exactly two
        classes via ``classes_``, this instance's ``predict_proba`` is replaced by
        a wrapper returning only column 1 of the original output (presumably the
        probability of the second/positive class — confirm for each framework).
        Models that expose no ``predict_proba`` or no ``classes_`` are left as-is.
        """
        if self.model_info["framework"] not in PREDICT_PROBA_FRAMEWORKS:
            return

        func = getattr(self, "predict_proba", None)
        if not func:
            return

        # Look up ``classes_`` defensively: a model tagged as sklearn/xgboost
        # that lacks the attribute should simply not be wrapped, rather than
        # raising AttributeError here.
        classes = getattr(self.model_like, "classes_", ())
        if len(classes) == 2:
            # Write straight into the instance dict so the wrapper shadows the
            # class-level/previously bound ``predict_proba`` for this instance only.
            self.__dict__["predict_proba"] = lambda x: func(x)[:, 1]

41 

42 

class DummyClassifier:
    """Stand-in classifier that replays precomputed predictions.

    Use this when a classification model itself is not available but its
    outputs are: wrap the array returned by the model's ``predict`` and/or the
    array returned by its ``predict_proba``, then pass the wrapper as the model
    to `ClassificationModel`.

    Parameters
    ----------
    name : str
        Label stored on the instance as ``name``.
    predict_output : array
        Array containing the output of a model's "predict" method
    predict_proba_output : array
        Array containing the output of a model's "predict_proba" method
    tags : optional
        Stored on the instance as ``tags``.
    """

    def __init__(
        self, name: str, predict_output=None, predict_proba_output=None, tags=None
    ):
        # Precomputed outputs, returned verbatim by the methods below.
        self.predict_output = predict_output
        self.predict_proba_output = predict_proba_output
        # Identification metadata.
        self.name = name
        self.tags = tags

    def predict(self, X=None):
        """Return the stored ``predict_output``; ``X`` is accepted but ignored."""
        stored = self.predict_output
        return stored

    def predict_proba(self, X=None):
        """Return the stored ``predict_proba_output``; ``X`` is accepted but ignored."""
        stored = self.predict_proba_output
        return stored