Coverage for credoai/utils/model_utils.py: 63%

30 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1import warnings 

2 

3from sklearn.base import is_classifier, is_regressor 

4from sklearn.ensemble import RandomForestClassifier 

5from sklearn.utils import multiclass 

6 

7 

def get_generic_classifier():
    """Return an unfitted, general-purpose classifier.

    Prefers ``xgboost.XGBClassifier`` (configured with
    ``use_label_encoder=False`` and ``eval_metric="logloss"``). Falls back to
    scikit-learn's ``RandomForestClassifier`` when xgboost is unavailable:
    not installed, broken on import, or failing to construct the estimator.

    Returns:
        A new, unfitted classifier instance.
    """
    with warnings.catch_warnings():
        # Silence FutureWarnings emitted while importing / constructing xgboost.
        warnings.simplefilter(action="ignore", category=FutureWarning)
        try:
            import xgboost as xgb
        except (ImportError, ValueError):
            # ImportError covers "not installed" (ModuleNotFoundError) as well
            # as partially broken installs. xgboost raises XGBoostError, a
            # ValueError subclass, from the import itself when its native
            # library cannot be loaded. `xgb` is unbound at that point, so the
            # error has to be caught here rather than by the handler below.
            return RandomForestClassifier()

        try:
            return xgb.XGBClassifier(use_label_encoder=False, eval_metric="logloss")
        except xgb.core.XGBoostError:
            # Construction failed inside xgboost; use the sklearn fallback.
            return RandomForestClassifier()

23 

24 

def get_model_info(model):
    """Describe which library a model object comes from.

    The "framework" is the top-level package of the module that defines the
    model's class (e.g. ``sklearn`` for ``sklearn.ensemble._forest``). It is
    ``None`` if that module path cannot be read.

    Returns:
        dict: ``{"framework": <top-level package name or None>}``.
    """
    try:
        top_level_package = model.__class__.__module__.partition(".")[0]
    except AttributeError:
        top_level_package = None
    return {"framework": top_level_package}

31 

32 

def get_default_metrics(model):
    """Pick default metric names for a model based on its sklearn estimator type.

    Returns:
        ``["accuracy_score"]`` for classifiers, ``["r2_score"]`` for
        regressors, and ``None`` when the model is neither. A new list is
        built on every call.
    """
    default_metrics = None
    # Classifier check first; is_regressor is consulted only when it fails.
    if is_classifier(model):
        default_metrics = ["accuracy_score"]
    elif is_regressor(model):
        default_metrics = ["r2_score"]
    return default_metrics

40 

41 

def type_of_target(target):
    """Classify ``target`` with ``sklearn.utils.multiclass.type_of_target``.

    Returns:
        The label type string sklearn reports for ``target``, or ``None``
        when ``target`` itself is ``None``.
    """
    if target is None:
        return None
    return multiclass.type_of_target(target)