Coverage for credoai/evaluators/model_profiler.py: 69%

102 statements  

coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

import re
from typing import Optional

import numpy as np
from connect.evidence import TableContainer
from connect.evidence.lens_evidence import ModelProfilerContainer
from pandas import DataFrame, concat

from credoai.evaluators.evaluator import Evaluator
from credoai.evaluators.utils.validation import check_existence
from credoai.utils import ValidationError, global_logger

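# Blank model-card template returned by ModelProfiler.generate_template();
# each entry is user-fillable metadata that cannot be inferred from the model itself.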

USER_INFO_TEMPLATE = {
    "developed_by": None,
    "shared_by": None,
    "model_type": None,
    "intended_use": None,
    "downstream_use": None,
    "out_of_scope_use": None,
    "language": None,
    "related_models": None,
    "license": None,
    "resources_for_more_info": None,
    "input_description": None,
    "output_description": None,
    "performance_evaluated_on": None,
    "limitations": None,
}

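# Keys reserved for fields the evaluator fills in automatically; user-supplied
# model_info must not reuse them (enforced in _validate_usr_model_info).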

PROTECTED_KEYS = [
    "model_name",
    "python_model_type",
    "library",
    "model_library",
    "feature_names",
    "parameters",
    "data_sample",
]


class ModelProfiler(Evaluator):
    """
    Model profiling evaluator (Experimental)

    This evaluator builds a model card that characterizes a fitted model.

    The overall strategy is:

    1. Extract all potentially useful info from the model itself in an
       automatic fashion.
    2. Allow the user to personalize the model card freely.

    The method generate_template() provides a dictionary with several entries
    the user may be interested in filling in.

    Parameters
    ----------
    model_info : Optional[dict]
        Information provided by the user that cannot be inferred from
        the model itself. The dictionary can contain any number of entries;
        a template can be generated by running the generate_template() method.

        The only restrictions are checked in a validation step:

        1. Some keys are protected because they are used internally
        2. Only basic Python types are accepted as values

    """


    required_artifacts = {"model", "assessment_data"}

    def __init__(self, model_info: Optional[dict] = None):
        super().__init__()
        self.usr_model_info = model_info
        if not self.usr_model_info:
            self.usr_model_info = {}
        self._validate_usr_model_info()
        self.logger = global_logger

    def _setup(self):
        self.model_name = self.model.name
        self.model_internal = self.model.model_like
        self.model_type = type(self.model_internal)

    def _validate_arguments(self):
        check_existence(self.model, "model")

    def evaluate(self):
        basic = self._get_basic_info()
        res = self._get_model_params()
        if self.model.model_info["framework"] == "keras":
            self.results = [
                TableContainer(
                    self._generate_keras_results_table(res, basic),
                    **self.get_info(),
                )
            ]
            return self
        # Add user generated info
        self.usr_model_info = {k: v for k, v in self.usr_model_info.items() if v}
        # Get a sample of the data
        data_sample = self._get_dataset_sample()
        # Collate info
        res = {**basic, **res, **self.usr_model_info, **data_sample}
        # Format
        res, labels = self._add_entries_labeling(res)

        # Package into evidence
        self.results = [ModelProfilerContainer(res, **self.get_info(labels=labels))]
        return self


    @staticmethod
    def _generate_keras_results_table(res, basic):
        # Basic info (model name, class) as a single-column frame
        basic = DataFrame(basic, index=[0]).T
        # Optimizer configuration, with keys prefixed for readability
        opt_info = DataFrame(res["parameters"]["optimizer_info"], index=[0])
        opt_info.columns = [f"optimizer.{x}" for x in opt_info.columns]
        opt_info = opt_info.T
        # Keep only the aggregate parameter counts
        choose = ["total_parameters", "trainable_parameters"]
        chosen = DataFrame(
            {k: v for k, v in res["parameters"].items() if k in choose}, index=[0]
        ).T

        # Collate everything into a two-column table
        output = concat([basic, chosen, opt_info])
        output = output.reset_index()
        output.columns = ["parameters", "values"]
        output.name = "model profile"
        return output


    def _get_basic_info(self) -> dict:
        """
        Collect basic information directly from the model artifact.

        Returns
        -------
        dict
            Dictionary containing the model name and full class identifier
        """
        return {
            "model_name": self.model_name,
            "python_model_type": str(self.model_type).split("'")[1],
        }

    def _get_dataset_sample(self) -> dict:
        """
        If assessment data is available, get a sample of it.
        """
        try:
            data_sample = {
                "data_sample": self.assessment_data.X.sample(
                    3, random_state=42
                ).to_dict(orient="list")
            }
            return data_sample

        except:
            message = "No data found -> a sample of the data won't be included in the model card"
            self.logger.info(message)
            return {}


    def _get_model_params(self) -> dict:
        """
        Select which parameter structure to utilize based on library/model used.

        Returns
        -------
        dict
            Dictionary of model info
        """
        if "sklearn" in str(self.model_type):
            return self._get_sklearn_model_params()
        if "keras" in str(self.model_type):
            return self._get_keras_model_params()
        self.logger.info(
            "Automatic model parameter inference not available for this model type."
        )
        return {}


    def _get_keras_model_params(self) -> dict:
        trainable_parameters = int(
            np.sum(
                [np.prod(v.get_shape()) for v in self.model_internal.trainable_weights]
            )
        )
        non_trainable_parameters = int(
            np.sum(
                [
                    np.prod(v.get_shape())
                    for v in self.model_internal.non_trainable_weights
                ]
            )
        )

        total_parameters = trainable_parameters + non_trainable_parameters

        opt_info = self.model_internal.optimizer.get_config()  # dict

        network_structure = DataFrame(
            [
                (
                    x.name,
                    re.sub(r"[^a-zA-Z]", "", str(type(x)).split(".")[-1]),
                    x.input_shape,
                    x.output_shape,
                    x.count_params(),
                )
                for x in self.model_internal.layers
            ],
            columns=["name", "layer_type", "input_shape", "output_shape", "parameters"],
        )

        return {
            "parameters": {
                "total_parameters": total_parameters,
                "trainable_parameters": trainable_parameters,
                "non_trainable_parameters": non_trainable_parameters,
                "network_architecture": network_structure,
                "optimizer_info": opt_info,
            }
        }


    def _get_sklearn_model_params(self) -> dict:
        """
        Get info from sklearn-like models

        Returns
        -------
        dict
            Dictionary of info about the model
        """
        parameters = self.model_internal.get_params()
        model_library = self.model_type.__name__
        library = "sklearn"
        if hasattr(self.model_internal, "feature_names_in_"):
            feature_names = list(self.model_internal.feature_names_in_)
        else:
            feature_names = None
        return {
            "library": library,
            "model_library": model_library,
            "parameters": parameters,
            "feature_names": feature_names,
        }


    def _validate_usr_model_info(self):
        """
        Validate the information that the user has provided manually.

        Any key is valid unless it is already in use internally. Values must
        be basic Python types (list, int, float, dict, str) or None.

        """
        protected = [k for k in self.usr_model_info.keys() if k in PROTECTED_KEYS]
        if protected:
            message = f"Found {protected} in model_info.keys(), these keys are already in use. Please rename/remove them."
            raise ValidationError(message)

        accepted_formats = (list, int, float, dict, str)
        non_accepted = [
            k
            for k, v in self.usr_model_info.items()
            if not isinstance(v, accepted_formats) and v is not None
        ]
        if non_accepted:
            message = f"The items {non_accepted} in model info are not of types: (list, int, float, dict, str)"
            raise ValidationError(message)


    @staticmethod
    def _add_entries_labeling(results: dict) -> tuple:
        """
        Format the combined entries and create labels to distinguish the
        user-generated ones.

        Parameters
        ----------
        results : dict
            Dictionary of all the entries

        Returns
        -------
        tuple
            DataFrame, dict
        """
        res = DataFrame.from_dict(results, orient="index")
        res.columns = ["results"]
        labels = {"user_generated": list(res.index[~res.index.isin(PROTECTED_KEYS)])}
        return res, labels


    @staticmethod
    def generate_template() -> dict:
        """
        Return a template for model-related info that the user
        can populate and customize.

        Loosely based on:
        https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md#model-details
        https://modelcards.withgoogle.com/model-reports

        Returns
        -------
        dict
            Dictionary of keys serving as placeholders for the user info
        """
        return USER_INFO_TEMPLATE
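
For orientation, a minimal usage sketch follows. Only generate_template() and the constructor's validation come from the module above; the Lens orchestration in the trailing comments is an assumption about the surrounding credoai-lens API rather than something this file defines.

from credoai.evaluators.model_profiler import ModelProfiler

# Fill the blank template with user-supplied model-card information.
info = ModelProfiler.generate_template()
info["developed_by"] = "Example Data Science Team"
info["intended_use"] = "Demonstration of the model profiler"

# Constructing the evaluator runs _validate_usr_model_info(): protected keys
# or non-basic value types would raise a ValidationError here.
profiler = ModelProfiler(model_info=info)

# Assumed orchestration (not defined in this module): a credoai-lens Lens
# object would supply the required "model" and "assessment_data" artifacts,
# e.g. lens = Lens(model=credo_model, assessment_data=credo_data);
#      lens.add(profiler); lens.run()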