Coverage for credoai/lens/lens_validation.py: 52%

75 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1import numpy as np 

2import pandas as pd 

3 

4try: 

5 tf_exists = True 

6 import tensorflow as tf 

7except ImportError: 

8 tf_exists = False 

9 

10import inspect 

11 

12from credoai.utils import global_logger 

13from credoai.utils.common import ValidationError 

14 

15############################################### 

16# Checking artifact interactions (model + data) 

17############################################### 

18 

19 

def check_model_data_consistency(model, data):
    """
    This validation function serves to check the compatibility of a model and dataset provided to Lens.
    For each outputting function (e.g., `predict`) supported by Lens, this validator checks
    whether the model supports that function. If so, the validator applies the outputting function to
    a small sample of the supplied dataset. The validator ensures the outputting function does not fail
    and, depending on the nature of the outputting function, performs light checking to verify the outputs
    (e.g. predictions) match the expected form, possibly including: data type and output shape.

    Only outputting functions present in the model's instance ``__dict__`` are checked
    (presumably they are attached there by the Lens Model wrapper -- confirm).

    Parameters
    ----------
    model : artifacts.Model or a subtype of artifacts.Model
        A trained machine learning model wrapped as a Lens Model object
    data : artifacts.Data or a subtype of artifacts.Data
        The dataset that will be assessed by Lens evaluators, wrapped as a Lens Data object

    Raises
    ------
    ValidationError
        If an outputting function fails on the sample or its output has an unexpected form.
        The underlying exception is chained as ``__cause__``.
    """
    # Conditions are kept in short-circuit order: `data.y`/`data.X`/`data.pairs` are only
    # accessed when the corresponding model function exists, since not every Data subtype
    # necessarily defines all of them.
    if "predict" in model.__dict__ and data.y is not None and data.X is not None:
        _check_predict(model, data)

    if "predict_proba" in model.__dict__ and data.y is not None and data.X is not None:
        _check_predict_proba(model, data)

    if "compare" in model.__dict__ and data.pairs is not None:
        _check_compare(model, data)


def _check_predict(model, data):
    """Validate that ``model.predict`` runs on a data sample and matches the per-sample shape of ``data.y``."""
    try:
        mini_pred, _ = check_prediction_model_output(model.predict, data)
        if not mini_pred.size:
            # results for all presently supported models are ndarray results
            raise ValueError("Empty return results from predict function.")
        if mini_pred.ndim > 1:
            # check that output size per sample matches up
            if isinstance(data.y, np.ndarray):
                expected_sample_shape = data.y.shape[1:]
            elif isinstance(data.y, pd.Series):
                # A Series is always 1-D, so its per-sample shape is ().
                expected_sample_shape = ()
            else:
                expected_sample_shape = None
            if (
                expected_sample_shape is not None
                and mini_pred.shape[1:] != expected_sample_shape
            ):
                raise ValueError("Predictions have mismatched shape from provided y")
    except Exception as e:
        raise ValidationError(
            "Lens.model predictions do not match expected form implied by provided labels y.",
            e,
        ) from e


def _check_predict_proba(model, data):
    """Validate that ``model.predict_proba`` runs on a data sample and yields valid probabilities."""
    try:
        mini_pred, _ = check_prediction_model_output(model.predict_proba, data)
        if not mini_pred.size:
            # results for all presently supported models are ndarray results
            raise ValueError("Empty return results from predict_proba function.")
        if mini_pred.ndim > 1 and mini_pred.shape[1] > 1:
            # Multi-column output: the first sample's row must sum to 1. Compare with a
            # tolerance -- float (especially float32 softmax) probabilities rarely sum to
            # exactly 1.0.
            if not np.isclose(np.sum(mini_pred[0]), 1.0):
                raise ValueError(
                    "`predict_proba` outputs invalid. Per-sample outputs should sum to 1."
                )
        elif np.any(mini_pred[0] > 1):
            # Single-column / 1-D output: a probability of exactly 1 is valid, only > 1 is not.
            raise ValueError(
                "`predict_proba` outputs invalid. Binary outputs should be <= 1."
            )
    except Exception as e:
        raise ValidationError(
            "Lens.model outputs do not match expected form implied by provided labels y.",
            e,
        ) from e


def _check_compare(model, data):
    """Validate that ``model.compare`` runs on a sample of ``data.pairs`` and returns a non-empty list."""
    try:
        comps, _ = check_comparison_model_output(model.compare, data)
        if not isinstance(comps, list):
            raise TypeError(
                "Comparison function expected to produce output of type list."
            )
        if not comps:
            # results are expected to be a list
            raise ValueError("Empty return results from compare function.")
    except Exception as e:
        # NOTE(review): this message mentions labels y although compare works on pairs;
        # kept unchanged in case callers match on it.
        raise ValidationError(
            "Lens.model outputs do not match expected form implied by provided labels y.",
            e,
        ) from e

104 

105 

def check_prediction_model_output(fn, data, batch_in: int = 1):
    """
    Helper function for prediction-type models. For use with `check_model_data_consistency`.

    This helper does the work of actually obtaining predictions (from `predict` or `predict_proba`;
    flexible enough for future use with functions that have similar behavior) on a small sample of
    `data.X`. Checking those predictions against the ground truth `data.y` is left to the caller.

    Parameters
    ----------
    fn : function object
        The prediction-generating function for the model passed to `check_model_data_consistency`
    data : artifacts.Data or a subtype of artifacts.Data
        The dataset that will be assessed by Lens evaluators, wrapped as a Lens Data object
    batch_in : int, default 1
        The number of samples to predict on for ndarray, DataFrame/Series and tf.Tensor inputs.
        We do not perform prediction on the entire `data.X` object since this could be large and
        computationally expensive. Batched inputs (tf.data.Dataset, generator functions,
        keras Sequence) use their first batch as-is and ignore this value.

    Returns
    -------
    tuple of (numpy.ndarray, int)
        The sample predictions converted with `np.array`, and the sample size the caller should
        assume: `batch_in`, or the size of the first batch for batched inputs.
    """
    X = data.X
    batch_out = batch_in
    if isinstance(X, np.ndarray):
        # Slice (rather than flatten) so each sample keeps its original shape, e.g. images.
        sample = X[:batch_in]
        if sample.ndim == 1:
            # 1-D X: treat each entry as one sample with a single feature -> (n, 1).
            sample = np.reshape(sample, (-1, 1))
        mini_pred = fn(sample)
    elif isinstance(X, (pd.DataFrame, pd.Series)):
        mini_pred = fn(X.head(batch_in))
    elif tf_exists and isinstance(X, tf.Tensor):
        mini_pred = fn(X[:batch_in])
    elif (tf_exists and isinstance(X, tf.data.Dataset)) or inspect.isgeneratorfunction(X):
        # A generator *function* is not iterable itself; call it to get a fresh iterator.
        source = X() if inspect.isgeneratorfunction(X) else X
        one_batch = next(iter(source))
        # A batch may be an (X, y[, sample_weight]) tuple with X first; a plain batch is X.
        # Checking the container type (not its length) avoids mistaking a batched tensor
        # of >= 2 samples for a tuple.
        x_batch = one_batch[0] if isinstance(one_batch, (tuple, list)) else one_batch
        mini_pred = fn(x_batch)
        batch_out = len(x_batch)
    elif tf_exists and isinstance(X, tf.keras.utils.Sequence):
        # NOTE(review): Sequence items are often (X, y) tuples; the item is passed to fn
        # unchanged -- confirm fn accepts that form.
        mini_pred = fn(X[0])
        batch_out = len(mini_pred)
    else:
        message = "Input X is of unsupported type. Behavior is undefined. Proceed with caution"
        global_logger.warning(message)
        mini_pred = fn(X[0])

    return np.array(mini_pred), batch_out

153 

154 

def check_comparison_model_output(fn, data, batch_in=1):
    """
    Run a comparison-type model function on a small slice of `data.pairs`.

    Used by `check_model_data_consistency` to confirm that `compare` (or any future function with
    similar behavior) can be called on the supplied pairs without failing. The outputs themselves
    are validated by the caller.

    Parameters
    ----------
    fn : function object
        The comparison-generating function for the model passed to `check_model_data_consistency`
    data : artifacts.Data or a subtype of artifacts.Data
        The dataset that will be assessed by Lens evaluators, wrapped as a Lens Data object
    batch_in : int, default 1
        How many pairs to compare. Only a slice is used because running `fn` on all of
        `data.pairs` could be large and computationally expensive.

    Returns
    -------
    tuple
        Whatever `fn` returned for the slice, and `batch_in` unchanged.
    """
    pairs = data.pairs
    if not isinstance(pairs, pd.DataFrame):
        # Not expected for ComparisonData (its wrapper enforces a DataFrame), but tolerated.
        global_logger.warning(
            "Input pairs are of unsupported type. Behavior is undefined. Proceed with caution"
        )
        return fn(pairs[:batch_in]), batch_in
    return fn(pairs.head(batch_in)), batch_in