Coverage for credoai/lens/lens_validation.py: 52%
75 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1import numpy as np
2import pandas as pd
4try:
5 tf_exists = True
6 import tensorflow as tf
7except ImportError:
8 tf_exists = False
10import inspect
12from credoai.utils import global_logger
13from credoai.utils.common import ValidationError
15###############################################
16# Checking artifact interactions (model + data)
17###############################################
def check_model_data_consistency(model, data):
    """
    Validate the compatibility of a model and dataset provided to Lens.

    For each outputting function (e.g., `predict`) supported by Lens, this validator checks
    whether the model supports that function. If so, the validator applies the outputting
    function to a small sample of the supplied dataset. The validator ensures the outputting
    function does not fail and, depending on the nature of the outputting function, performs
    light checking to verify the outputs (e.g. predictions) match the expected form, possibly
    including: data type and output shape.

    Parameters
    ----------
    model : artifacts.Model or a subtype of artifacts.Model
        A trained machine learning model wrapped as a Lens Model object
    data : artifacts.Data or a subtype of artifacts.Data
        The dataset that will be assessed by Lens evaluators, wrapped as a Lens Data object

    Raises
    ------
    ValidationError
        If a supported output function fails on a data sample or its outputs are
        inconsistent with the form implied by the provided dataset.
    """
    # check predict
    # Keras always outputs numpy types (not Tensor or something else)
    if "predict" in model.__dict__ and data.y is not None and data.X is not None:
        try:
            mini_pred, batch_size = check_prediction_model_output(model.predict, data)
            if not mini_pred.size:
                # results for all presently supported models are ndarray results
                raise Exception("Empty return results from predict function.")
            if len(mini_pred.shape) > 1:
                # check that per-sample output shape matches the shape implied by y
                if isinstance(data.y, np.ndarray) and (
                    mini_pred.shape[1:] != data.y.shape[1:]
                ):
                    raise Exception("Predictions have mismatched shape from provided y")
                elif isinstance(data.y, pd.Series) and (
                    mini_pred.shape[1:] != data.y.head(batch_size).shape[1:]
                ):
                    raise Exception("Predictions have mismatched shape from provided y")
        except Exception as e:
            raise ValidationError(
                "Lens.model predictions do not match expected form implied by provided labels y.",
                e,
            )

    if (
        "predict_proba" in model.__dict__
        and data.y is not None
        and data.X is not None
    ):
        try:
            mini_pred, batch_size = check_prediction_model_output(
                model.predict_proba, data
            )
            if not mini_pred.size:
                # results for all presently supported models are ndarray results
                raise Exception("Empty return results from predict_proba function.")
            if len(mini_pred.shape) > 1 and mini_pred.shape[1] > 1:
                # Multiclass: per-sample probabilities should sum to 1. Use isclose
                # rather than exact equality — float arithmetic (e.g. softmax outputs)
                # rarely sums to exactly 1.0.
                if not np.isclose(np.sum(mini_pred[0]), 1):
                    raise Exception(
                        "`predict_proba` outputs invalid. Per-sample outputs should sum to 1."
                    )
            else:
                # Binary: a single probability must not exceed 1. Strict `>` — a
                # probability of exactly 1.0 is valid (the previous `>=` wrongly
                # rejected it, contradicting the error message below).
                if mini_pred[0] > 1:
                    raise Exception(
                        "`predict_proba` outputs invalid. Binary outputs should be <= 1."
                    )
        except Exception as e:
            raise ValidationError(
                "Lens.model outputs do not match expected form implied by provided labels y.",
                e,
            )

    if "compare" in model.__dict__ and data.pairs is not None:
        try:
            comps, batch_size = check_comparison_model_output(model.compare, data)
            if not isinstance(comps, list):
                raise Exception(
                    "Comparison function expected to produce output of type list."
                )
            if not comps:
                # results are expected to be a list
                raise Exception("Empty return results from compare function.")
        except Exception as e:
            # comparison models validate against data.pairs, not labels y
            raise ValidationError(
                "Lens.model outputs do not match expected form implied by provided pairs.",
                e,
            )
def check_prediction_model_output(fn, data, batch_in: int = 1):
    """
    Helper function for prediction-type models. For use with `check_model_data_consistency`.

    This helper does the work of actually obtaining predictions (from `predict` or
    `predict_proba`; flexible enough for future use with functions that have similar
    behavior) so the caller can verify the outputs are consistent with the expected
    outputs specified by the ground truth `data.y`.

    Parameters
    ----------
    fn : function object
        The prediction-generating function for the model passed to `check_model_data_consistency`
    data : artifacts.Data or a subtype of artifacts.Data
        The dataset that will be assessed by Lens evaluators, wrapped as a Lens Data object
    batch_in : int
        The size of the sample prediction. We do not perform prediction on the entire
        `data.X` object since this could be large and computationally expensive.

    Returns
    -------
    tuple of (numpy.ndarray, int)
        The sample predictions and the effective batch size used to obtain them.
    """
    mini_pred = None
    batch_out = batch_in
    if isinstance(data.X, np.ndarray):
        # predict on a single sample, reshaped to a 1-row 2-D array
        mini_pred = fn(np.reshape(data.X[0], (1, -1)))
    elif isinstance(data.X, (pd.DataFrame, pd.Series)):
        mini_pred = fn(data.X.head(batch_in))
    elif tf_exists and isinstance(data.X, tf.Tensor):
        mini_pred = fn(data.X)
    elif (
        tf_exists and isinstance(data.X, tf.data.Dataset)
    ) or inspect.isgenerator(data.X):
        # BUG FIX: was `inspect.isgeneratorfunction`, which matches a *function*
        # defined with `yield`; `iter()` on such an object raises TypeError, so the
        # branch could never succeed for it. `inspect.isgenerator` matches the
        # generator objects that `next(iter(...))` below actually consumes.
        one_batch = next(iter(data.X))
        # NOTE(review): this is the arity of the batch tuple (e.g. 2 for (X, y)),
        # not the number of samples in the batch — confirm intended
        batch_out = len(one_batch)
        if len(one_batch) >= 2:
            # batch is tuple
            # includes y and possibly weights; X is first
            mini_pred = fn(one_batch[0])
        else:
            # batch only contains X
            mini_pred = fn(one_batch)
    elif tf_exists and isinstance(data.X, tf.keras.utils.Sequence):
        mini_pred = fn(data.X.__getitem__(0))
        batch_out = len(mini_pred)
    else:
        message = "Input X is of unsupported type. Behavior is undefined. Proceed with caution"
        global_logger.warning(message)
        mini_pred = fn(data.X[0])

    return np.array(mini_pred), batch_out
def check_comparison_model_output(fn, data, batch_in=1):
    """
    Helper function for comparison-type models. For use with `check_model_data_consistency`.

    Obtains a small sample of comparisons (from `compare`; flexible enough for future
    use with functions that have similar behavior) to verify the function does not fail.

    Parameters
    ----------
    fn : function object
        The comparison-generating function for the model passed to `check_model_data_consistency`
    data : artifacts.Data or a subtype of artifacts.Data
        The dataset that will be assessed by Lens evaluators, wrapped as a Lens Data object
    batch_in : int
        The size of the sample comparison. We do not run the comparison on the entire
        `data.pairs` object since this could be large and computationally expensive.

    Returns
    -------
    tuple
        The comparison results produced by `fn` and the batch size that was used.
    """
    if isinstance(data.pairs, pd.DataFrame):
        # should always pass for ComparisonData, based on checks in that wrapper. Nevertheless...
        sample = data.pairs.head(batch_in)
    else:
        global_logger.warning(
            "Input pairs are of unsupported type. Behavior is undefined. Proceed with caution"
        )
        sample = data.pairs[:batch_in]

    return fn(sample), batch_in