Coverage for credoai/artifacts/model/classification_model.py: 100%
20 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
1"""Model artifact wrapping any classification model"""
2from .base_model import Model
# Frameworks for which `ClassificationModel._update_functionality` may replace a
# binary model's `predict_proba` with a wrapper returning only column 1.
PREDICT_PROBA_FRAMEWORKS = [
    "sklearn",
    "xgboost",
]
class ClassificationModel(Model):
    """Class wrapper around classification model to be assessed

    ClassificationModel serves as an adapter between arbitrary binary or multi-class
    classification models and the evaluations in Lens. Evaluations depend on
    ClassificationModel implementing `predict` and (optionally) `predict_proba`.

    Parameters
    ----------
    name : str
        Label of the model
    model_like : model_like
        A binary or multi-class classification model or pipeline. It must have a
        `predict` function that returns array containing the class labels for each sample.
        It can also optionally have a `predict_proba` function that returns array containing
        the class labels probabilities for each sample.
    tags : optional
        Forwarded unchanged to the `Model` base class constructor.
    """

    def __init__(self, name: str, model_like=None, tags=None):
        # Arguments, in order: model type label, possible functions,
        # required functions, then the user-supplied name/model/tags.
        super().__init__(
            "Classification",
            ["predict", "predict_proba"],
            ["predict"],
            name,
            model_like,
            tags,
        )

    def _update_functionality(self):
        """Conditionally updates functionality based on framework

        Applies only when ``self.model_info["framework"]`` is listed in
        ``PREDICT_PROBA_FRAMEWORKS``, the instance has a ``predict_proba``
        attribute, and ``model_like.classes_`` has exactly two entries.

        In that case ``predict_proba`` is overwritten on the instance with a
        wrapper. The wrapper returns only column 1 (``[:, 1]``) of the original
        output, i.e. scores for the second entry of ``classes_``. Presumably
        this is the positive class; confirm with the framework's class-ordering
        convention.

        Models without a ``classes_`` attribute are left unchanged.
        """
        if self.model_info["framework"] not in PREDICT_PROBA_FRAMEWORKS:
            return

        func = getattr(self, "predict_proba", None)
        if func is None:
            return

        # Not every model_like from these frameworks exposes `classes_`.
        # Without it we cannot tell the problem is binary, so keep the original
        # `predict_proba` rather than raising AttributeError.
        classes = getattr(self.model_like, "classes_", None)
        if classes is None or len(classes) != 2:
            return

        # `func` is captured now, so the wrapper keeps calling the original
        # method even after the instance attribute below is replaced.
        # Written straight into the instance __dict__ (bypasses any custom
        # __setattr__).
        self.__dict__["predict_proba"] = lambda x: func(x)[:, 1]
class DummyClassifier:
    """Class wrapper around classification model predictions

    Use this class when a classification model is not available but its outputs
    are. The outputs are the array of predicted class labels and/or the array of
    class label probabilities. Wrap them in this dummy classifier and pass it as
    the model to `ClassificationModel`.

    Parameters
    ----------
    name : str
        Label of the model, stored as ``self.name``
    predict_output : array
        Array containing the output of a model's "predict" method
    predict_proba_output : array
        Array containing the output of a model's "predict_proba" method
    tags : optional
        Stored unchanged as ``self.tags``
    """

    def __init__(
        self, name: str, predict_output=None, predict_proba_output=None, tags=None
    ):
        # Identification metadata first, then the canned model outputs.
        self.name, self.tags = name, tags
        self.predict_output, self.predict_proba_output = (
            predict_output,
            predict_proba_output,
        )

    def predict(self, X=None):
        """Return the stored ``predict_output``; ``X`` is accepted but ignored."""
        stored_labels = self.predict_output
        return stored_labels

    def predict_proba(self, X=None):
        """Return the stored ``predict_proba_output``; ``X`` is accepted but ignored."""
        stored_probabilities = self.predict_proba_output
        return stored_probabilities