Coverage for credoai/evaluators/shap_credoai.py: 84%

87 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1import enum 

2from typing import Dict, List, Optional, Union 

3 

4import numpy as np 

5import pandas as pd 

6from connect.evidence import TableContainer 

7from shap import Explainer, Explanation, kmeans 

8 

9from credoai.evaluators.evaluator import Evaluator 

10from credoai.evaluators.utils.validation import check_requirements_existence 

11from credoai.utils.common import ValidationError 

12 

13 

class ShapExplainer(Evaluator):
    """
    This evaluator performs the calculation of shapley values for a dataset/model,
    leveraging the SHAP package.

    It supports 2 types of assessments:

    1. Overall statistics of the shap values across all samples: mean and mean(|x|)
    2. Individual shapley values for a list of samples

    Sampling
    --------
    In order to speed up computation time, at the stage in which the SHAP explainer is
    initialized, a down sampled version of the dataset is passed to the `Explainer`
    object as background data. This is only affecting the calculation of the reference
    value, the calculation of the shap values is still performed on the full dataset.

    Two strategies for down sampling are provided:

    1. Random sampling (the default strategy): the amount of samples can be specified
       by the user.
    2. Kmeans: summarizes a dataset with k mean centroids, weighted by the number of
       data points they each represent. The amount of centroids can also be specified
       by the user.

    There is no consensus on the optimal down sampling approach. For reference, see this
    conversation: https://github.com/slundberg/shap/issues/1018

    Categorical variables
    ---------------------
    The interpretation of the results for categorical variables can be more challenging, and
    dependent on the type of encoding utilized. Ordinal or one/hot encoding can be hard to
    interpret.

    There is no agreement as to what is the best strategy as far as categorical variables are
    concerned. A good discussion on this can be found here: https://github.com/slundberg/shap/issues/451

    No restriction on feature type is imposed by the evaluator, so user discretion in the
    interpretation of shap values for categorical variables is advised.

    Parameters
    ----------
    samples_ind : Optional[List[int]], optional
        List of row numbers representing the samples for which to extract individual
        shapley values. This must be a list of integer indices. The underlying SHAP
        library does not support non-integer indexing.
    background_samples : int,
        Amount of samples to be taken from the dataset in order to build the reference values.
        See documentation about sampling above. Unused if background_kmeans is not False.
    background_kmeans : Union[bool, int], optional
        If True, use SHAP kmeans to create a data summary to serve as background data for the
        SHAP explainer using 50 centroids by default. If an int is provided,
        that will be used as the number of centroids. If False, random sampling will take place.
    """

    required_artifacts = ["assessment_data", "model"]

    def __init__(
        self,
        samples_ind: Optional[List[int]] = None,
        background_samples: int = 100,
        background_kmeans: Union[bool, int] = False,
    ):
        super().__init__()
        self.samples_ind = samples_ind
        self._validate_samples_ind()
        self.background_samples = background_samples
        self.background_kmeans = background_kmeans
        # Placeholder for the single-output case; replaced with the model's
        # classes_ when multi-class output is detected in _setup_shap().
        self.classes = [None]

    def _validate_arguments(self):
        check_requirements_existence(self)

    def _setup(self):
        # Cache the assessment feature matrix; self.model is already set by
        # the framework (the previous `self.model = self.model` was a no-op).
        self.X = self.assessment_data.X
        return self

    def evaluate(self):
        """
        Run the SHAP assessment and populate self.results.

        Always produces the overall shap statistics; additionally produces
        per-sample shapley values when samples_ind was provided.
        """
        ## Overall stats
        self._setup_shap()
        self.results = [
            TableContainer(self._get_overall_shap_contributions(), **self.get_info())
        ]

        ## Sample specific results
        if self.samples_ind:
            ind_res = self._get_mult_sample_shapley_values()
            self.results += [TableContainer(ind_res, **self.get_info())]
        return self

    def _setup_shap(self):
        """
        Setup the explainer given the model and the feature dataset.

        Builds the background data (kmeans summary or random sample), creates
        the SHAP Explainer, computes shap values for the full dataset, and
        normalizes the values into a list of 2D DataFrames (one per class for
        multi-class models).
        """
        if self.background_kmeans:
            # `type(...) is int` (not isinstance) is intentional:
            # isinstance(True, int) is True, and background_kmeans=True must
            # select the default number of centroids rather than 1.
            if type(self.background_kmeans) is int:
                centroids_num = self.background_kmeans
            else:
                centroids_num = 50
            data_summary = kmeans(self.X, centroids_num).data
        else:
            data_summary = self.X.sample(self.background_samples)
        # Try the model-like object first, which only works for models that
        # shap natively supports; otherwise fall back to the predict function.
        try:
            explainer = Explainer(self.model.model_like, data_summary)
        except Exception:
            explainer = Explainer(self.model.predict, data_summary)
        # Generating the actual values calling the specific Shap function
        self.shap_values = explainer(self.X)

        # Define values dataframes and classes variables depending on
        # the shape of the returned values. This accounts for multi class
        # classification.
        s_values = self.shap_values.values
        if s_values.ndim == 2:
            self.values_df = [pd.DataFrame(s_values)]
        elif s_values.ndim == 3:
            # One frame per class: slice the trailing (class) axis.
            self.values_df = [
                pd.DataFrame(s_values[:, :, i]) for i in range(s_values.shape[2])
            ]
            self.classes = self.model.model_like.classes_
        else:
            raise RuntimeError(
                f"Shap values have unsupported format. Detected shape {s_values.shape}"
            )
        return self

    def _get_overall_shap_contributions(self) -> pd.DataFrame:
        """
        Calculate overall SHAP contributions for a dataset.

        The output of SHAP package provides Shapley values for each sample in a
        dataset. To summarize the contribution of each feature in a dataset, the
        samples contributions need to be aggregated.

        For each of the features, this method provides: mean and the
        mean of the absolute value of the samples Shapley values.

        Returns
        -------
        pd.DataFrame
            Summary of the Shapley values across the full dataset.
        """
        shap_summaries = [
            self._summarize_shap_values(frame) for frame in self.values_df
        ]
        if len(self.classes) > 1:
            # Bug fix: DataFrame.assign returns a NEW frame; the previous code
            # discarded the result, so class_label was never actually added.
            shap_summaries = [
                summary.assign(class_label=label)
                for label, summary in zip(self.classes, shap_summaries)
            ]
        # fmt: off
        shap_summary = (
            pd.concat(shap_summaries)
            .reset_index()
            .rename({"index": "feature_name"}, axis=1)
        )
        # fmt: on
        shap_summary.name = "Summary of Shap statistics"
        return shap_summary

    def _summarize_shap_values(self, shap_val: pd.DataFrame) -> pd.DataFrame:
        """
        Summarize Shap values at a dataset level.

        Parameters
        ----------
        shap_val : pd.DataFrame
            Table containing shap values, if the model output is multiclass,
            the table corresponds to the values for a single class.

        Returns
        -------
        pd.DataFrame
            Summarized shap values, sorted by mean(|x|) descending.
        """
        shap_val.columns = self.shap_values.feature_names
        summaries = {"mean": np.mean, "mean(|x|)": lambda x: np.mean(np.abs(x))}
        results = map(lambda func: shap_val.apply(func), summaries.values())
        # fmt: off
        final = (
            pd.concat(results, axis=1)
            .set_axis(list(summaries), axis=1)
            .sort_values("mean(|x|)", ascending=False)
        )
        # fmt: on
        final.name = "Summary of Shap statistics"
        return final

    def _get_mult_sample_shapley_values(self) -> pd.DataFrame:
        """
        Return shapley values for multiple samples from the dataset.

        Returns
        -------
        pd.DataFrame
            Columns:
                values -> shap values
                ref_value -> Reference value for the shap values
                    (generally the same across the dataset)
                sample_pos -> Position of the sample in the dataset
        """
        all_sample_shaps = []
        for ind in self.samples_ind:
            sample_results = self._get_single_sample_values(self.shap_values[ind])
            sample_results = sample_results.assign(sample_pos=ind)
            all_sample_shaps.append(sample_results)

        res = pd.concat(all_sample_shaps)
        res.name = "Shap values for specific samples"
        return res

    def _validate_samples_ind(self, limit=5):
        """
        Enforce limit on maximum amount of samples for which to extract
        individual shap values.

        A maximum number of samples is enforced, this is in order to constrain the
        amount of information in transit to Credo AI Platform, both for performance
        and security reasons.

        Parameters
        ----------
        limit : int, optional
            Max number of samples allowed, by default 5.

        Raises
        ------
        ValidationError
            If more than `limit` sample indices were provided.
        """
        if self.samples_ind is not None and len(self.samples_ind) > limit:
            # Message now reflects the actual limit instead of a hard-coded 5.
            raise ValidationError(
                f"The maximum amount of individual samples_ind allowed is {limit}."
            )

    def _get_single_sample_values(self, sample_shap: Explanation) -> pd.DataFrame:
        """
        Returns shapley values for a specific sample in the dataset.

        Parameters
        ----------
        sample_shap : Explanation
            Explainer object output for a specific sample.

        Returns
        -------
        pd.DataFrame
            Columns: values, ref_value (and class_label/column_names).
            Contains shapley values for the sample, and the reference value.
            The model prediction for the sample is equal to: ref_value + sum(values)
        """
        # Single-output (regression / binary): one row per feature.
        if len(self.classes) == 1:
            return pd.DataFrame({"values": sample_shap.values}).assign(
                ref_value=sample_shap.base_values,
                column_names=self.shap_values.feature_names,
            )

        # Multi-class: one frame per class, concatenated vertically.
        class_values = []
        for label, cls in enumerate(self.classes):
            class_values.append(
                pd.DataFrame({"values": sample_shap.values[:, label]}).assign(
                    class_label=cls,
                    ref_value=sample_shap.base_values[label],
                    column_names=self.shap_values.feature_names,
                )
            )
        return pd.concat(class_values)