Coverage for credoai/evaluators/model_profiler.py: 84%

74 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1from typing import Optional 

2 

3from connect.evidence.lens_evidence import ModelProfilerContainer 

4from credoai.evaluators import Evaluator 

5from credoai.evaluators.utils.validation import check_existence 

6from credoai.utils import ValidationError, global_logger 

7from pandas import DataFrame 

8 

# Suggested model-card entries a user may want to fill in. Every entry starts
# out as None; ModelProfiler.generate_template() hands this mapping to the user.
USER_INFO_TEMPLATE = dict.fromkeys(
    (
        "developed_by",
        "shared_by",
        "model_type",
        "intended_use",
        "downstream_use",
        "out_of_scope_use",
        "language",
        "related_models",
        "license",
        "resources_for_more_info",
        "input_description",
        "output_description",
        "performance_evaluated_on",
        "limitations",
    )
)

25 

# Keys written into the model card by ModelProfiler itself. User-supplied
# model_info may not reuse them, and any result entry NOT listed here is
# labelled as user generated.
PROTECTED_KEYS = [
    # produced by _get_basic_info
    "model_name",
    "python_model_type",
    # produced by _get_sklearn_model_params
    "library",
    "model_library",
    "feature_names",
    "parameters",
    # produced by _get_dataset_sample
    "data_sample",
]

35 

36 

class ModelProfiler(Evaluator):
    """
    Model profiling evaluator.

    This evaluator builds a model card, the purpose of which is to characterize
    a fitted model.

    The overall strategy is:

    1. Extract all potentially useful info from the model itself in an
       automatic fashion.
    2. Allow the user to personalize the model card freely.

    The method generate_template() provides a dictionary with several entries
    the user could be interested in filling in.

    Parameters
    ----------
    model_info : Optional[dict]
        Information provided by the user that cannot be inferred from
        the model itself. The dictionary can contain any number of elements;
        a template can be obtained by running the generate_template() method.

        The only restrictions are checked in a validation step:

        1. Some keys are protected because they are used internally
        2. Only basic python types are accepted as values

    Raises
    ------
    ValidationError
        If ``model_info`` uses a protected key or holds a value of an
        unsupported type.
    """

    required_artifacts = {"model", "assessment_data"}

    def __init__(self, model_info: Optional[dict] = None):
        super().__init__()
        # Treat a missing (or empty) user dictionary as "no user info".
        self.usr_model_info = model_info
        if not self.usr_model_info:
            self.usr_model_info = {}
        self._validate_usr_model_info()
        self.logger = global_logger

    def _setup(self):
        """
        Cache the model name and type, then unwrap the model artifact.

        NOTE(review): ``self.model`` is rebound from the artifact wrapper to
        its ``model_like`` attribute, so this is presumably meant to run only
        once per instance -- confirm against the Evaluator base class.
        """
        self.model_name = self.model.name
        self.model = self.model.model_like
        self.model_type = type(self.model)

    def _validate_arguments(self):
        """Check that a model artifact has been provided."""
        check_existence(self.model, "model")

    def evaluate(self):
        """
        Assemble the model card and store it in ``self.results``.

        Returns
        -------
        ModelProfiler
            The evaluator itself, with ``results`` populated.
        """
        basic = self._get_basic_info()
        inferred = self._get_model_params()
        # Add user generated info; entries with a falsy value (e.g. the None
        # placeholders of the template) are dropped.
        self.usr_model_info = {k: v for k, v in self.usr_model_info.items() if v}
        # Get a sample of the data
        data_sample = self._get_dataset_sample()
        # Collate info (on a key clash the later mapping wins)
        card = {**basic, **inferred, **self.usr_model_info, **data_sample}
        # Format
        card, labels = self._add_entries_labeling(card)
        # Package into evidence
        self.results = [
            ModelProfilerContainer(card, **self.get_container_info(labels=labels))
        ]
        return self

    def _get_basic_info(self) -> dict:
        """
        Collect basic information directly from the model artifact.

        Returns
        -------
        dict
            Dictionary containing name, full class identifier
        """
        return {
            "model_name": self.model_name,
            # str(type) looks like "<class 'pkg.mod.Name'>"; keep the quoted part.
            "python_model_type": str(self.model_type).split("'")[1],
        }

    def _get_dataset_sample(self) -> dict:
        """
        If assessment data is available get a sample of it.

        Up to three rows are drawn with a fixed random state. Any failure to
        access or sample the data is logged and results in an empty dictionary.

        Returns
        -------
        dict
            ``{"data_sample": {column: [values, ...]}}`` or ``{}`` when no
            sample could be produced.
        """
        sample = None
        try:
            features = self.assessment_data.X
            # Cap the sample size so that datasets with fewer than three rows
            # still contribute a sample instead of failing inside sample().
            n_rows = min(3, len(features))
            if n_rows:
                sample = features.sample(n_rows, random_state=42).to_dict(
                    orient="list"
                )
        except Exception:
            # Best effort: the data sample is optional. Only Exception is
            # caught so KeyboardInterrupt/SystemExit still propagate.
            sample = None

        if sample is None:
            message = "No data found -> a sample of the data won't be included in the model card"
            self.logger.info(message)
            return {}
        return {"data_sample": sample}

    def _get_model_params(self) -> dict:
        """
        Select which parameter structure to utilize based on library/model used.

        Returns
        -------
        dict
            Dictionary of model info
        """
        if "sklearn" in str(self.model_type):
            return self._get_sklearn_model_params()

        self.logger.info(
            "Automatic model parameter inference not available for this model type."
        )
        return {}

    def _get_sklearn_model_params(self) -> dict:
        """
        Get info from sklearn like models

        Returns
        -------
        dict
            Dictionary of info about the model
        """
        if hasattr(self.model, "feature_names_in_"):
            feature_names = list(self.model.feature_names_in_)
        else:
            feature_names = None
        return {
            "library": "sklearn",
            "model_library": self.model_type.__name__,
            "parameters": self.model.get_params(),
            "feature_names": feature_names,
        }

    def _validate_usr_model_info(self):
        """
        Validate information that the user has inputted manually.

        Any key is valid unless it's already in use internally.

        Raises
        ------
        ValidationError
            If a protected key is used, or a value is neither None nor one of
            list, int, float, dict, str.
        """
        protected = [k for k in self.usr_model_info.keys() if k in PROTECTED_KEYS]
        if protected:
            message = f"Found {protected} in model_info.keys(), these keys are already in use. Please rename/remove them."
            raise ValidationError(message)

        accepted_formats = (list, int, float, dict, str)
        non_accepted = [
            k
            for k, v in self.usr_model_info.items()
            if v is not None and not isinstance(v, accepted_formats)
        ]
        if non_accepted:
            message = f"The items {non_accepted} in model info are not of types: (list, int, float, dict, str)"
            raise ValidationError(message)

    @staticmethod
    def _add_entries_labeling(results: dict) -> tuple:
        """
        Takes the combined entries and format + create label to distinguish
        user generated ones.

        Parameters
        ----------
        results : dict
            Dictionary of all the entries

        Returns
        -------
        tuple
            DataFrame, dict
        """
        res = DataFrame.from_dict(results, orient="index")
        res.columns = ["results"]
        # Every entry that is not produced internally counts as user generated.
        labels = {"user_generated": list(res.index[~res.index.isin(PROTECTED_KEYS)])}
        return res, labels

    @staticmethod
    def generate_template() -> dict:
        """
        Passes a template for model related info that the user
        can populate and customize.

        Loosely based on:
        https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md#model-details
        https://modelcards.withgoogle.com/model-reports

        Returns
        -------
        dict
            Dictionary of keys working as bookmarks for the user info
        """
        # Return a copy: the user is expected to fill the template in, which
        # must not alter the shared module-level constant for later callers.
        return dict(USER_INFO_TEMPLATE)