Coverage for credoai/evaluators/model_profiler.py: 69%
102 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1import re
2from typing import Optional
4import numpy as np
5from connect.evidence import TableContainer
6from connect.evidence.lens_evidence import ModelProfilerContainer
7from pandas import DataFrame, concat
9from credoai.evaluators.evaluator import Evaluator
10from credoai.evaluators.utils.validation import check_existence
11from credoai.utils import ValidationError, global_logger
# Suggested keys for the free-form, user-written part of the model card.
# Every entry starts unset (None); used by ModelProfiler.generate_template().
USER_INFO_TEMPLATE = dict.fromkeys(
    (
        "developed_by",
        "shared_by",
        "model_type",
        "intended_use",
        "downstream_use",
        "out_of_scope_use",
        "language",
        "related_models",
        "license",
        "resources_for_more_info",
        "input_description",
        "output_description",
        "performance_evaluated_on",
        "limitations",
    )
)

# Keys populated internally by the profiler; user-supplied model_info must not
# use any of them.
PROTECTED_KEYS = list(
    (
        "model_name",
        "python_model_type",
        "library",
        "model_library",
        "feature_names",
        "parameters",
        "data_sample",
    )
)
class ModelProfiler(Evaluator):
    """
    Model profiling evaluator (Experimental)

    This evaluator builds a model card the purpose of which is to characterize
    a fitted model.

    The overall strategy is:

    1. Extract all potentially useful info from the model itself in an
       automatic fashion.
    2. Allow the user to personalize the model card freely.

    The method generate_template() provides a dictionary with several entries the
    user could be interested in filling up.

    Parameters
    ----------
    model_info : Optional[dict]
        Information provided by the user that cannot be inferred by
        the model itself. The dictionary can contain any number of elements,
        a template can be provided by running the generate_template() method.

        The only restrictions are checked in a validation step:

        1. Some keys are protected because they are used internally
        2. Only basic python types are accepted as values
    """

    required_artifacts = {"model", "assessment_data"}

    def __init__(self, model_info: Optional[dict] = None):
        super().__init__()
        # None (or an empty mapping) means "no user-supplied info".
        self.usr_model_info = model_info or {}
        # Fail fast at construction time on protected keys / unsupported values.
        self._validate_usr_model_info()
        self.logger = global_logger

    def _setup(self):
        """Cache the model's name, the wrapped model object and its Python type."""
        self.model_name = self.model.name
        self.model_internal = self.model.model_like
        self.model_type = type(self.model_internal)

    def _validate_arguments(self):
        """Check that a model artifact was provided (via check_existence)."""
        check_existence(self.model, "model")

    def evaluate(self):
        """
        Build the model profile and store it in ``self.results``.

        For models whose ``model_info["framework"]`` is ``"keras"`` the result
        is a single TableContainer with basic info, parameter counts and the
        optimizer configuration; user info and the data sample are not part
        of that table. For every other model the result is a single
        ModelProfilerContainer combining basic info, inferred parameters,
        user-supplied info and (when available) a sample of the data.

        Returns
        -------
        ModelProfiler
            ``self``, to allow chaining.
        """
        basic = self._get_basic_info()
        res = self._get_model_params()
        if self.model.model_info["framework"] == "keras":
            self.results = [
                TableContainer(
                    self._generate_keras_results_table(res, basic),
                    **self.get_info(),
                )
            ]
            return self
        # Keep only truthy user entries: this drops unfilled template
        # placeholders (None) but also 0, "" and empty containers.
        self.usr_model_info = {k: v for k, v in self.usr_model_info.items() if v}
        data_sample = self._get_dataset_sample()
        # Later mappings win on key collisions; user keys cannot collide with
        # the internal ones because PROTECTED_KEYS are rejected at validation.
        res = {**basic, **res, **self.usr_model_info, **data_sample}
        res, labels = self._add_entries_labeling(res)

        self.results = [ModelProfilerContainer(res, **self.get_info(labels=labels))]
        return self

    @staticmethod
    def _generate_keras_results_table(res, basic):
        """
        Flatten keras profiling info into a two-column table.

        Parameters
        ----------
        res : dict
            Output of ``_get_keras_model_params``; its ``"parameters"`` entry
            must contain ``"optimizer_info"``, ``"total_parameters"`` and
            ``"trainable_parameters"``.
        basic : dict
            Output of ``_get_basic_info``.

        Returns
        -------
        DataFrame
            Columns ``["parameters", "values"]``, one row per basic-info
            entry, chosen parameter count and optimizer setting; the frame's
            ``name`` attribute is set to ``"model profile"``.
        """
        basic = DataFrame(basic, index=[0]).T
        opt_info = DataFrame(res["parameters"]["optimizer_info"], index=[0])
        # Prefix optimizer settings so they don't clash with other row labels.
        opt_info.columns = [f"optimizer.{x}" for x in opt_info.columns]
        opt_info = opt_info.T
        choose = ["total_parameters", "trainable_parameters"]
        chosen = DataFrame(
            {k: v for k, v in res["parameters"].items() if k in choose}, index=[0]
        ).T

        output = concat([basic, chosen, opt_info])
        output = output.reset_index()
        output.columns = ["parameters", "values"]
        output.name = "model profile"
        return output

    def _get_basic_info(self) -> dict:
        """
        Collect basic information directly from the model artifact.

        Returns
        -------
        dict
            Dictionary containing name, full class identifier
        """
        return {
            "model_name": self.model_name,
            # str(type) looks like "<class 'pkg.module.Cls'>"; keep the quoted part.
            "python_model_type": str(self.model_type).split("'")[1],
        }

    def _get_dataset_sample(self) -> dict:
        """
        If assessment data is available, get a small sample of it.

        Up to 3 rows of ``self.assessment_data.X`` are drawn with a fixed
        ``random_state`` so the model card is reproducible. Datasets with
        fewer than 3 rows are sampled in full rather than skipped.

        Returns
        -------
        dict
            ``{"data_sample": {column: [values, ...]}}``, or an empty dict
            when no sample could be produced (missing/empty/unsupported data).
        """
        message = "No data found -> a sample of the data won't be included in the model card"
        try:
            features = self.assessment_data.X
            n_rows = min(3, len(features))
            if n_rows == 0:
                self.logger.info(message)
                return {}
            sample = features.sample(n_rows, random_state=42).to_dict(orient="list")
        except Exception:
            # Best effort: absent or unsupported data must not abort profiling.
            # `except Exception` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            self.logger.info(message)
            return {}
        return {"data_sample": sample}

    def _get_model_params(self) -> dict:
        """
        Select which parameter structure to utilize based on library/model used.

        Dispatch is done on the textual class path of the model type.

        Returns
        -------
        dict
            Dictionary of model info; empty when the model type is not
            recognized as sklearn or keras.
        """
        if "sklearn" in str(self.model_type):
            return self._get_sklearn_model_params()
        if "keras" in str(self.model_type):
            return self._get_keras_model_params()
        self.logger.info(
            "Automatic model parameter inference not available for this model type."
        )
        return {}

    def _get_keras_model_params(self) -> dict:
        """
        Get info from keras models.

        Returns
        -------
        dict
            ``{"parameters": {...}}`` with trainable / non-trainable / total
            parameter counts, a per-layer ``network_architecture`` DataFrame
            and the optimizer configuration dict.
        """
        trainable_parameters = int(
            np.sum(
                [np.prod(v.get_shape()) for v in self.model_internal.trainable_weights]
            )
        )
        non_trainable_parameters = int(
            np.sum(
                [
                    np.prod(v.get_shape())
                    for v in self.model_internal.non_trainable_weights
                ]
            )
        )

        total_parameters = trainable_parameters + non_trainable_parameters

        opt_info = self.model_internal.optimizer.get_config()  # dict

        network_structure = DataFrame(
            [
                (
                    x.name,
                    # "<class '...layers.Dense'>" -> "Dense'>" -> "Dense"
                    re.sub(r"[^a-zA-Z]", "", str(type(x)).split(".")[-1]),
                    x.input_shape,
                    x.output_shape,
                    x.count_params(),
                )
                for x in self.model_internal.layers
            ],
            columns=["name", "layer_type", "input_shape", "output_shape", "parameters"],
        )

        return {
            "parameters": {
                "total_parameters": total_parameters,
                "trainable_parameters": trainable_parameters,
                "non_trainable_parameters": non_trainable_parameters,
                "network_architecture": network_structure,
                "optimizer_info": opt_info,
            }
        }

    def _get_sklearn_model_params(self) -> dict:
        """
        Get info from sklearn like models

        Returns
        -------
        dict
            Dictionary of info about the model: library, estimator class name,
            ``get_params()`` output and the fitted feature names (None when the
            model does not expose ``feature_names_in_``).
        """
        parameters = self.model_internal.get_params()
        model_library = self.model_type.__name__
        library = "sklearn"
        if hasattr(self.model_internal, "feature_names_in_"):
            feature_names = list(self.model_internal.feature_names_in_)
        else:
            feature_names = None
        return {
            "library": library,
            "model_library": model_library,
            "parameters": parameters,
            "feature_names": feature_names,
        }

    def _validate_usr_model_info(self):
        """
        Validate information that the user has inputted manually.

        Any key is valid unless it's already in use internally, and every
        non-None value must be a list, int, float, dict or str.

        Raises
        ------
        ValidationError
            If a protected key is used or a value has an unsupported type.
        """
        protected = [k for k in self.usr_model_info.keys() if k in PROTECTED_KEYS]
        if protected:
            message = f"Found {protected} in model_info.keys(), these keys are already in use. Please rename/remove them."
            raise ValidationError(message)

        accepted_formats = (list, int, float, dict, str)
        # None is allowed: it marks an unfilled template entry.
        non_accepted = [
            k
            for k, v in self.usr_model_info.items()
            if not isinstance(v, accepted_formats) and v is not None
        ]
        if non_accepted:
            message = f"The items {non_accepted} in model info are not of types: (list, int, float, dict, str)"
            raise ValidationError(message)

    @staticmethod
    def _add_entries_labeling(results: dict) -> tuple:
        """
        Takes the combined entries and format + create label to distinguish
        user generated ones.

        Parameters
        ----------
        results : dict
            Dictionary of all the entries

        Returns
        -------
        tuple
            DataFrame, dict: a one-column ("results") frame indexed by entry
            name, and ``{"user_generated": [...]}`` listing every entry whose
            key is not in PROTECTED_KEYS.
        """
        res = DataFrame.from_dict(results, orient="index")
        res.columns = ["results"]
        labels = {"user_generated": list(res.index[~res.index.isin(PROTECTED_KEYS)])}
        return res, labels

    @staticmethod
    def generate_template() -> dict:
        """
        Passes a template for model related info that the user
        can populate and customize.

        Loosely based on:
        https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md#model-details
        https://modelcards.withgoogle.com/model-reports

        Returns
        -------
        dict
            A fresh dictionary of keys working as bookmarks for the user info,
            all set to None. It is safe to modify.
        """
        # Return a copy: users are expected to fill the template in, and handing
        # out the module-level dict would leak their edits into every later call.
        return dict(USER_INFO_TEMPLATE)