Coverage for credoai/utils/model_utils.py: 63%
30 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
1import warnings
3from sklearn.base import is_classifier, is_regressor
4from sklearn.ensemble import RandomForestClassifier
5from sklearn.utils import multiclass
def get_generic_classifier():
    """Return an unfitted, general-purpose classifier.

    Prefers ``xgboost.XGBClassifier`` (with ``eval_metric="logloss"``) when
    xgboost can be imported and the estimator constructed. Otherwise falls
    back to scikit-learn's ``RandomForestClassifier`` with default
    parameters.

    FutureWarnings emitted while importing or constructing the model are
    ignored. The surrounding ``warnings.catch_warnings()`` restores the
    caller's warning filters on exit.

    Returns:
        An unfitted estimator instance.
    """
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=FutureWarning)
        try:
            import xgboost as xgb
        except (ImportError, ValueError):
            # ImportError: xgboost is missing or only partially installed.
            # ValueError: xgboost raises XGBoostError (a ValueError subclass)
            # during import when its native library cannot be loaded, so the
            # import must be guarded too, not only the constructor.
            return RandomForestClassifier()
        try:
            # NOTE(review): newer xgboost releases deprecate
            # `use_label_encoder`; confirm it is still needed for the
            # supported xgboost version range.
            return xgb.XGBClassifier(
                use_label_encoder=False, eval_metric="logloss"
            )
        except xgb.core.XGBoostError:
            return RandomForestClassifier()
def get_model_info(model):
    """Describe which framework *model* comes from.

    The framework is taken as the top-level package of the module that
    defines ``model``'s class (e.g. ``"sklearn"`` for
    ``sklearn.ensemble._forest``).

    Returns:
        A dict ``{"framework": name}``, where ``name`` is ``None`` if the
        class/module information cannot be read.
    """
    framework = None
    try:
        module_path = model.__class__.__module__
        framework = module_path.partition(".")[0]
    except AttributeError:
        pass
    return {"framework": framework}
def get_default_metrics(model):
    """Choose a default list of metric names for *model*.

    Returns:
        ``["accuracy_score"]`` if scikit-learn's ``is_classifier`` accepts
        the model, else ``["r2_score"]`` if ``is_regressor`` accepts it,
        else ``None``.
    """
    if is_classifier(model):
        metrics = ["accuracy_score"]
    elif is_regressor(model):
        metrics = ["r2_score"]
    else:
        metrics = None
    return metrics
def type_of_target(target):
    """Return scikit-learn's target-type label for *target*.

    Delegates to ``sklearn.utils.multiclass.type_of_target``. A ``None``
    target is passed through as ``None`` instead of being delegated.
    """
    if target is None:
        return None
    return multiclass.type_of_target(target)