Coverage for credoai/utils/model_utils.py: 38%

78 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1import warnings 

2 

3from sklearn.base import is_classifier, is_regressor 

4from sklearn.ensemble import RandomForestClassifier 

5from sklearn.utils import multiclass 

6 

7from credoai.utils import global_logger 

8 

9try: 

10 from tensorflow.keras import layers 

11except ImportError: 

12 pass 

13 

14 

def get_generic_classifier():
    """Return a new, unfitted general-purpose classifier.

    Prefers ``xgboost.XGBClassifier``. Falls back to scikit-learn's
    ``RandomForestClassifier`` when xgboost cannot be imported or the
    XGBoost classifier cannot be constructed. FutureWarnings raised while
    importing xgboost or building the model are suppressed.

    Returns
    -------
    xgboost.XGBClassifier or sklearn.ensemble.RandomForestClassifier
        A freshly constructed classifier instance.
    """
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=FutureWarning)
        try:
            import xgboost as xgb
        except (ImportError, ValueError):
            # ImportError: xgboost is not installed (ModuleNotFoundError is a
            # subclass). ValueError: xgboost raises XGBoostError (a ValueError
            # subclass) from the import itself when its native library cannot
            # be loaded; the construction-time handler below never sees that.
            return RandomForestClassifier()
        try:
            return xgb.XGBClassifier(use_label_encoder=False, eval_metric="logloss")
        except xgb.core.XGBoostError:
            return RandomForestClassifier()

30 

31 

def get_model_info(model):
    """Return basic identifying information about ``model``.

    The framework is the model's ``framework_like`` attribute when that is
    present and truthy. Otherwise it is the first dotted component of the
    module that defines the model's class (e.g. ``"sklearn"`` for a class
    defined in ``sklearn.linear_model``).

    Returns
    -------
    dict
        ``{"framework": ..., "lib_name": ...}``, where ``lib_name`` is the
        model's class name. Either value is ``None`` if it could not be read.
    """
    try:
        framework = getattr(model, "framework_like", None) or (
            model.__class__.__module__.split(".")[0]
        )
    except AttributeError:
        framework = None

    try:
        lib_name = model.__class__.__name__
    except AttributeError:
        lib_name = None

    return {"framework": framework, "lib_name": lib_name}

45 

46 

def get_default_metrics(model):
    """Return the default metric names for ``model``.

    Uses scikit-learn's ``is_classifier`` / ``is_regressor`` to classify the
    model.

    Returns
    -------
    list of str or None
        ``["accuracy_score"]`` for classifiers, ``["r2_score"]`` for
        regressors, and ``None`` when the model is neither.
    """
    if is_classifier(model):
        return ["accuracy_score"]
    if is_regressor(model):
        return ["r2_score"]
    return None

54 

55 

def type_of_target(target):
    """Return scikit-learn's target-type label for ``target``.

    ``None`` is passed through unchanged. Every other value is forwarded to
    ``sklearn.utils.multiclass.type_of_target``.
    """
    if target is None:
        return None
    return multiclass.type_of_target(target)

58 

59 

60############################################# 

61# Validation Functions for Various Model Types 

62############################################# 

def validate_sklearn_like(model_obj, model_info: dict):
    """Validate a scikit-learn-style model; currently performs no checks.

    Kept as a hook with the same ``(model_obj, model_info)`` signature as the
    other ``validate_*`` functions in this module.

    Returns
    -------
    None
    """
    return None

65 

66 

def validate_keras_clf(model_obj, model_info: dict):
    """Warn about Keras classifier configurations that are not supported.

    All checks are advisory. Each problem found is reported through
    ``global_logger.warning``; nothing is raised for an unsupported
    configuration.

    Parameters
    ----------
    model_obj : keras model
        Model to inspect. Its ``layers`` list and the final layer's
        ``output.shape`` are read.
    model_info : dict
        Model description as returned by ``get_model_info``; only the
        ``"lib_name"`` entry is used.
    """
    # This is how Keras checks sequential too: https://github.com/keras-team/keras/blob/master/keras/utils/layer_utils.py#L219
    if model_info["lib_name"] != "Sequential":
        message = "Only Keras models with Sequential architecture are supported at this time. "
        message += "Using Keras with other architectures has undefined behavior."
        global_logger.warning(message)

    if not model_obj.layers:
        global_logger.warning(
            "Keras model has no layers; cannot validate its output layer."
        )
        return

    final_layer = model_obj.layers[-1]
    # Not every layer exposes ``activation`` (``layers.Softmax`` does not), so
    # look it up defensively. ``None`` means "no named activation".
    activation_name = getattr(
        getattr(final_layer, "activation", None), "__name__", None
    )
    is_softmax_layer = isinstance(final_layer, layers.Softmax)
    is_dense = isinstance(final_layer, layers.Dense)

    valid_final_layer = is_softmax_layer or (
        is_dense and activation_name in ("softmax", "sigmoid")
    )
    if not valid_final_layer:
        message = "Expected output layer to be either: tf.keras.layers.Softmax or "
        message += "tf.keras.layers.Dense with softmax or sigmoid activation."
        global_logger.warning(message)

    output_shape = final_layer.output.shape
    if len(output_shape) != 2:
        message = "Expected 2D output shape for Keras.Sequential model: (batch_size, n_classes) or (None, n_classes)"
        global_logger.warning(message)

    if len(output_shape) >= 1 and output_shape[0] is not None:
        message = "Expected output shape of Keras model to have arbitrary length"
        global_logger.warning(message)

    # The class-count checks need a known second dimension. Without one the
    # comparisons below would raise IndexError (rank < 2) or TypeError (None).
    n_classes = output_shape[1] if len(output_shape) >= 2 else None
    if n_classes is None:
        return

    if n_classes < 2 and activation_name != "sigmoid":
        message = "Expected classification output shape (batch_size, n_classes) or (None, n_classes). "
        message += "Univariate outputs not supported at this time."
        global_logger.warning(message)

    if n_classes > 2 and activation_name != "softmax" and not is_softmax_layer:
        message = "Expected multiclass classification to use softmax activation with "
        message += "output shape (batch_size, n_classes) or (None, n_classes). "
        message += "Non-softmax classification not supported at this time."
        global_logger.warning(message)
    # TODO Add support for model-imposed argmax layer
    # https://stackoverflow.com/questions/56704669/keras-output-single-value-through-argmax

117 

118 

def validate_dummy(model_like, _):
    """Validate the model wrapped by ``model_like``, if it wraps one.

    Reads ``model_like.model_like``. If it is truthy, dispatches to the
    validator matching the framework reported by ``get_model_info``.

    Parameters
    ----------
    model_like : object
        Object exposing a ``model_like`` attribute holding the wrapped model
        (or a falsy value when nothing is wrapped).
    _ : object
        Unused; present to mirror the ``(model, model_info)`` signature of
        the other ``validate_*`` functions in this module.

    Raises
    ------
    ValueError
        If the wrapped model's framework is not keras, sklearn or xgboost.
    """
    inner_model = model_like.model_like
    if not inner_model:
        return
    inner_info = get_model_info(inner_model)
    framework = inner_info["framework"]
    if framework == "keras":
        validate_keras_clf(inner_model, inner_info)
    elif framework in ("sklearn", "xgboost"):
        validate_sklearn_like(inner_model, inner_info)
    else:
        # Raise an explicit error. A bare ``raise`` here has no active
        # exception to re-raise and fails with an unrelated RuntimeError.
        raise ValueError(
            f"Unsupported framework for wrapped model: {framework!r}. "
            "Supported frameworks are keras, sklearn and xgboost."
        )