Coverage for credoai/evaluators/shap.py: 84%

87 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1import enum 

2from typing import Dict, List, Optional, Union 

3 

4import numpy as np 

5import pandas as pd 

6from shap import Explainer, Explanation, kmeans 

7 

8from credoai.evaluators import Evaluator 

9from credoai.evaluators.utils.validation import check_requirements_existence 

10from connect.evidence import TableContainer 

11from credoai.utils.common import ValidationError 

12 

13 

class ShapExplainer(Evaluator):
    """
    This evaluator performs the calculation of shapley values for a dataset/model,
    leveraging the SHAP package.

    It supports 2 types of assessments:

    1. Overall statistics of the shap values across all samples: mean and mean(|x|)
    2. Individual shapley values for a list of samples

    Sampling
    --------
    In order to speed up computation time, at the stage in which the SHAP explainer is
    initialized, a down sampled version of the dataset is passed to the `Explainer`
    object as background data. This is only affecting the calculation of the reference
    value, the calculation of the shap values is still performed on the full dataset.

    Two strategies for down sampling are provided:

    1. Random sampling (the default strategy): the amount of samples can be specified
       by the user.
    2. Kmeans: summarizes a dataset with k mean centroids, weighted by the number of
       data points they each represent. The amount of centroids can also be specified
       by the user.

    There is no consensus on the optimal down sampling approach. For reference, see this
    conversation: https://github.com/slundberg/shap/issues/1018

    Categorical variables
    ---------------------
    The interpretation of the results for categorical variables can be more challenging, and
    dependent on the type of encoding utilized. Ordinal or one-hot encoding can be hard to
    interpret.

    There is no agreement as to what is the best strategy as far as categorical variables are
    concerned. A good discussion on this can be found here: https://github.com/slundberg/shap/issues/451

    No restriction on feature type is imposed by the evaluator, so user discretion in the
    interpretation of shap values for categorical variables is advised.

    Parameters
    ----------
    samples_ind : Optional[List[int]], optional
        List of row numbers representing the samples for which to extract individual
        shapley values. This must be a list of integer indices. The underlying SHAP
        library does not support non-integer indexing.
    background_samples : int, optional
        Amount of samples to be taken from the dataset in order to build the reference
        values. See documentation about sampling above. Unused if background_kmeans is
        not False. Defaults to 100.
    background_kmeans : Union[bool, int], optional
        If True, use SHAP kmeans to create a data summary to serve as background data
        for the SHAP explainer using 50 centroids by default. If an int is provided,
        that will be used as the number of centroids. If False, random sampling will
        take place.
    """

    required_artifacts = ["assessment_data", "model"]

    def __init__(
        self,
        samples_ind: Optional[List[int]] = None,
        background_samples: int = 100,
        background_kmeans: Union[bool, int] = False,
    ):
        super().__init__()
        self.samples_ind = samples_ind
        self._validate_samples_ind()
        self.background_samples = background_samples
        self.background_kmeans = background_kmeans
        # Single placeholder entry -> treated as "not multiclass"; overwritten in
        # _setup_shap when shap returns one set of values per class.
        self.classes = [None]

    def _validate_arguments(self):
        """Check that the required artifacts (assessment_data, model) exist."""
        check_requirements_existence(self)

    def _setup(self):
        """Bind the feature dataframe from the assessment data artifact."""
        self.X = self.assessment_data.X
        return self

    def evaluate(self):
        """
        Run the evaluation.

        Computes overall shap statistics for the dataset and, if `samples_ind`
        was provided, the individual shap values for those samples. Results are
        stored as TableContainers in self.results.
        """
        ## Overall stats
        self._setup_shap()
        self.results = [
            TableContainer(
                self._get_overall_shap_contributions(), **self.get_container_info()
            )
        ]

        ## Sample specific results
        if self.samples_ind:
            ind_res = self._get_mult_sample_shapley_values()
            self.results += [TableContainer(ind_res, **self.get_container_info())]
        return self

    def _setup_shap(self):
        """
        Setup the explainer given the model and the feature dataset.
        """
        if self.background_kmeans:
            # NOTE: `type(...) is int` (rather than isinstance) is deliberate:
            # bool is a subclass of int, and isinstance would silently turn
            # background_kmeans=True into a 1-centroid summary instead of the
            # documented default of 50.
            if type(self.background_kmeans) is int:
                centroids_num = self.background_kmeans
            else:
                centroids_num = 50
            data_summary = kmeans(self.X, centroids_num).data
        else:
            data_summary = self.X.sample(self.background_samples)
        # Try to use the model-like directly, which will only work if it is a
        # model that shap supports; otherwise fall back to the generic predict
        # callable. Narrowed from a bare `except:` so SystemExit and
        # KeyboardInterrupt are not swallowed.
        try:
            explainer = Explainer(self.model.model_like, data_summary)
        except Exception:
            explainer = Explainer(self.model.predict, data_summary)
        # Generating the actual values calling the specific Shap function
        self.shap_values = explainer(self.X)

        # Define values dataframes and classes variables depending on
        # the shape of the returned values. This accounts for multi class
        # classification
        s_values = self.shap_values.values
        if len(s_values.shape) == 2:
            self.values_df = [pd.DataFrame(s_values)]
        elif len(s_values.shape) == 3:
            # One dataframe per class; third axis indexes the classes.
            self.values_df = [
                pd.DataFrame(s_values[:, :, i]) for i in range(s_values.shape[2])
            ]
            self.classes = self.model.model_like.classes_
        else:
            raise RuntimeError(
                f"Shap values have unsupported format. Detected shape {s_values.shape}"
            )
        return self

    def _get_overall_shap_contributions(self) -> pd.DataFrame:
        """
        Calculate overall SHAP contributions for a dataset.

        The output of SHAP package provides Shapley values for each sample in a
        dataset. To summarize the contribution of each feature in a dataset, the
        samples contributions need to be aggregated.

        For each of the features, this method provides: mean and the
        mean of the absolute value of the samples Shapley values.

        Returns
        -------
        pd.DataFrame
            Summary of the Shapley values across the full dataset.
        """
        shap_summaries = [
            self._summarize_shap_values(frame) for frame in self.values_df
        ]
        if len(self.classes) > 1:
            # BUGFIX: DataFrame.assign returns a *new* frame; the previous code
            # discarded the return value, so class_label was never added. The
            # assigned frames must be captured for the label column to survive.
            shap_summaries = [
                df.assign(class_label=label)
                for label, df in zip(self.classes, shap_summaries)
            ]
        # fmt: off
        shap_summary = (
            pd.concat(shap_summaries)
            .reset_index()
            .rename({"index": "feature_name"}, axis=1)
        )
        # fmt: on
        shap_summary.name = "Summary of Shap statistics"
        return shap_summary

    def _summarize_shap_values(self, shap_val: pd.DataFrame) -> pd.DataFrame:
        """
        Summarize Shap values at a Dataset level.

        Parameters
        ----------
        shap_val : pd.DataFrame
            Table containing shap values, if the model output is multiclass,
            the table corresponds to the values for a single class.

        Returns
        -------
        pd.DataFrame
            Summarized shap values, sorted by mean(|x|) descending.
        """
        shap_val.columns = self.shap_values.feature_names
        summaries = {"mean": np.mean, "mean(|x|)": lambda x: np.mean(np.abs(x))}
        results = map(lambda func: shap_val.apply(func), summaries.values())
        # fmt: off
        final = (
            pd.concat(results, axis=1)
            .set_axis(summaries.keys(), axis=1)
            .sort_values("mean(|x|)", ascending=False)
        )
        # fmt: on
        final.name = "Summary of Shap statistics"
        return final

    def _get_mult_sample_shapley_values(self) -> pd.DataFrame:
        """
        Return shapley values for multiple samples from the dataset.

        Returns
        -------
        pd.DataFrame
            Columns:
                values -> shap values
                ref_value -> Reference value for the shap values
                    (generally the same across the dataset)
                sample_pos -> Position of the sample in the dataset
        """
        all_sample_shaps = []
        for ind in self.samples_ind:
            sample_results = self._get_single_sample_values(self.shap_values[ind])
            sample_results = sample_results.assign(sample_pos=ind)
            all_sample_shaps.append(sample_results)

        res = pd.concat(all_sample_shaps)
        res.name = "Shap values for specific samples"
        return res

    def _validate_samples_ind(self, limit: int = 5):
        """
        Enforce limit on maximum amount of samples for which to extract
        individual shap values.

        A maximum number of samples is enforced, this is in order to constrain the
        amount of information in transit to Credo AI Platform, both for performance
        and security reasons.

        Parameters
        ----------
        limit : int, optional
            Max number of samples allowed, by default 5.

        Raises
        ------
        ValidationError
            If more than `limit` sample indices were provided.
        """
        if self.samples_ind is not None:
            if len(self.samples_ind) > limit:
                # Message derives from `limit` so it stays correct if the
                # default ever changes (previously hardcoded "5").
                message = (
                    f"The maximum amount of individual samples_ind allowed is {limit}."
                )
                raise ValidationError(message)

    def _get_single_sample_values(self, sample_shap: Explanation) -> pd.DataFrame:
        """
        Returns shapley values for a specific sample in the dataset.

        Parameters
        ----------
        sample_shap : Explanation
            Explainer object output for a specific sample.

        Returns
        -------
        pd.DataFrame
            Columns: values, ref_value (and class_label/column_names).
            Contains shapley values for the sample, and the reference value.
            The model prediction for the sample is equal to: ref_value + sum(values)
        """
        class_values = []

        if len(self.classes) == 1:
            # Binary/regression case: a single vector of shap values.
            return pd.DataFrame({"values": sample_shap.values}).assign(
                ref_value=sample_shap.base_values,
                column_names=self.shap_values.feature_names,
            )

        # Multiclass case: one column of values and one base value per class.
        for label, cls in enumerate(self.classes):
            class_values.append(
                pd.DataFrame({"values": sample_shap.values[:, label]}).assign(
                    class_label=cls,
                    ref_value=sample_shap.base_values[label],
                    column_names=self.shap_values.feature_names,
                )
            )
        return pd.concat(class_values)