Coverage for credoai/utils/model_utils.py: 38%
78 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1import warnings
3from sklearn.base import is_classifier, is_regressor
4from sklearn.ensemble import RandomForestClassifier
5from sklearn.utils import multiclass
7from credoai.utils import global_logger
9try:
10 from tensorflow.keras import layers
11except ImportError:
12 pass
def get_generic_classifier():
    """Return an untrained, general-purpose classifier.

    An ``xgboost.XGBClassifier`` (``eval_metric="logloss"``) is returned when
    xgboost can be imported and the classifier can be constructed. Otherwise a
    scikit-learn ``RandomForestClassifier`` with default settings is returned.
    FutureWarnings emitted while importing/constructing the model are
    suppressed.

    Returns
    -------
    An unfitted classifier instance.
    """
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=FutureWarning)
        try:
            import xgboost as xgb
        except (ImportError, ValueError):
            # ImportError: xgboost is missing or only partially installed.
            # ValueError: xgboost raises XGBoostError (a ValueError subclass)
            # at import time when its native library cannot be loaded, so the
            # fallback must cover that case as well.
            return RandomForestClassifier()

        try:
            return xgb.XGBClassifier(use_label_encoder=False, eval_metric="logloss")
        except xgb.core.XGBoostError:
            return RandomForestClassifier()
def get_model_info(model):
    """Describe which library a model comes from and what its class is called.

    The framework is the model's ``framework_like`` attribute when that
    attribute exists and is truthy. Otherwise it is the top-level package of
    the module that defines the model's class.

    Parameters
    ----------
    model : object
        Any model object.

    Returns
    -------
    dict
        ``{"framework": ..., "lib_name": ...}``; either value is ``None`` if it
        could not be determined.
    """
    try:
        framework = getattr(model, "framework_like", None) or (
            model.__class__.__module__.split(".")[0]
        )
    except AttributeError:
        framework = None

    try:
        lib_name = model.__class__.__name__
    except AttributeError:
        lib_name = None

    return {"framework": framework, "lib_name": lib_name}
def get_default_metrics(model):
    """Choose default metric names based on the scikit-learn estimator type.

    Returns
    -------
    list of str or None
        ``["accuracy_score"]`` if ``is_classifier(model)``, else
        ``["r2_score"]`` if ``is_regressor(model)``, else ``None``.
    """
    # Checked in order: classifier first, then regressor.
    for predicate, metric_name in (
        (is_classifier, "accuracy_score"),
        (is_regressor, "r2_score"),
    ):
        if predicate(model):
            return [metric_name]
    return None
def type_of_target(target):
    """Return scikit-learn's target-type string for ``target``, or ``None``.

    Delegates to ``sklearn.utils.multiclass.type_of_target``. A ``None``
    target returns ``None`` without calling scikit-learn.
    """
    if target is None:
        return None
    return multiclass.type_of_target(target)
60#############################################
61# Validation Functions for Various Model Types
62#############################################
def validate_sklearn_like(model_obj, model_info: dict):
    """Validate a scikit-learn-like model.

    This is currently a placeholder. No checks are performed, and ``None`` is
    returned whatever the arguments are.
    """
def validate_keras_clf(model_obj, model_info: dict):
    """Check that a Keras model looks like a supported classifier.

    Each problem found is reported through ``global_logger.warning``. Nothing
    is raised for an unsupported configuration.

    Checks performed:
      * the model is ``Sequential`` (judged by ``model_info["lib_name"]``);
      * the last layer is ``layers.Softmax``, or ``layers.Dense`` with softmax
        or sigmoid activation;
      * the last layer's output is 2D with an unspecified batch dimension;
      * a single output unit uses sigmoid, and more than two output units use
        softmax.

    Parameters
    ----------
    model_obj : keras model
        The model to check; ``model_obj.layers[-1]`` is inspected.
    model_info : dict
        Model metadata as produced by ``get_model_info``; only ``"lib_name"``
        is read.
    """
    # This is how Keras checks sequential too: https://github.com/keras-team/keras/blob/master/keras/utils/layer_utils.py#L219
    if model_info["lib_name"] != "Sequential":
        message = "Only Keras models with Sequential architecture are supported at this time. "
        message += "Using Keras with other architectures has undefined behavior."
        global_logger.warning(message)

    final_layer = model_obj.layers[-1]
    # Not every layer type has an ``activation`` attribute (layers.Softmax does
    # not), so look it up defensively instead of raising AttributeError.
    activation_name = getattr(
        getattr(final_layer, "activation", None), "__name__", None
    )
    is_softmax_layer = isinstance(final_layer, layers.Softmax)
    is_dense = isinstance(final_layer, layers.Dense)

    valid_final_layer = is_softmax_layer or (
        is_dense and activation_name in ("softmax", "sigmoid")
    )
    if not valid_final_layer:
        message = "Expected output layer to be either: tf.keras.layers.Softmax or "
        message += "tf.keras.layers.Dense with softmax or sigmoid activation."
        global_logger.warning(message)

    output_shape = final_layer.output.shape
    if len(output_shape) != 2:
        message = "Expected 2D output shape for Keras.Sequential model: (batch_size, n_classes) or (None, n_classes)"
        global_logger.warning(message)

    if len(output_shape) > 0 and output_shape[0] is not None:
        message = "Expected output shape of Keras model to have arbitrary length"
        global_logger.warning(message)

    # The class-count checks need a known second dimension. For non-2D output
    # (already warned about above) or an undetermined class dimension, indexing
    # or comparing it would raise, so stop here.
    n_classes = output_shape[1] if len(output_shape) == 2 else None
    if n_classes is None:
        return

    if n_classes < 2 and activation_name != "sigmoid":
        message = "Expected classification output shape (batch_size, n_classes) or (None, n_classes). "
        message += "Univariate outputs not supported at this time."
        global_logger.warning(message)

    if n_classes > 2 and activation_name != "softmax" and not is_softmax_layer:
        message = "Expected multiclass classification to use softmax activation with "
        message += "output shape (batch_size, n_classes) or (None, n_classes). "
        message += "Non-softmax classification not supported at this time."
        global_logger.warning(message)
    # TODO Add support for model-imposed argmax layer
    # https://stackoverflow.com/questions/56704669/keras-output-single-value-through-argmax
def validate_dummy(model_like, _):
    """Validate the real model wrapped by a dummy model, if there is one.

    Parameters
    ----------
    model_like : object
        A dummy model whose ``model_like`` attribute holds the wrapped model,
        or ``None`` when nothing is wrapped.
    _ :
        Unused. Presumably model info, kept so the signature matches the other
        validators; confirm against callers.

    Raises
    ------
    NotImplementedError
        If the wrapped model's framework is not keras, sklearn or xgboost.
    """
    wrapped = model_like.model_like
    # Compare against None instead of relying on truthiness. Some estimators
    # define __len__ (e.g. scikit-learn ensembles, which raise when unfitted),
    # so bool(model) can be False or fail for a valid model.
    if wrapped is None:
        return

    info = get_model_info(wrapped)
    framework = info["framework"]
    if framework == "keras":
        validate_keras_clf(wrapped, info)
    elif framework in ("sklearn", "xgboost"):
        validate_sklearn_like(wrapped, info)
    else:
        # The former bare ``raise`` had no active exception and surfaced as a
        # RuntimeError. NotImplementedError subclasses RuntimeError, so existing
        # handlers keep working, and the message now explains the failure.
        raise NotImplementedError(
            f"Unsupported framework for model wrapped by dummy model: {framework!r}"
        )
127 raise