Coverage for credoai/evaluators/shap.py: 84% (87 statements)
import enum
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
from shap import Explainer, Explanation, kmeans

from credoai.evaluators import Evaluator
from credoai.evaluators.utils.validation import check_requirements_existence
from connect.evidence import TableContainer
from credoai.utils.common import ValidationError


class ShapExplainer(Evaluator):
    """
    This evaluator calculates Shapley values for a dataset/model pair,
    leveraging the SHAP package.

    It supports two types of assessment:

    1. Overall statistics of the shap values across all samples: mean and mean(|x|)
    2. Individual shapley values for a list of samples

    Sampling
    --------
    In order to speed up computation time, a downsampled version of the dataset
    is passed to the `Explainer` object as background data at the stage in which
    the SHAP explainer is initialized. This affects only the calculation of the
    reference value; the shap values are still calculated on the full dataset.

    Two downsampling strategies are provided (see the sketch below):

    1. Random sampling (the default strategy): the number of samples can be
       specified by the user.
    2. Kmeans: summarizes the dataset with k mean centroids, weighted by the
       number of data points they each represent. The number of centroids can
       also be specified by the user.

    There is no consensus on the optimal downsampling approach. For reference,
    see this conversation: https://github.com/slundberg/shap/issues/1018
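
    Conceptually, the two strategies reduce to the following (a sketch of the
    internal behavior; ``n_centroids`` is an illustrative name)::

        background = X.sample(background_samples)    # 1. random sampling
        background = kmeans(X, n_centroids).data     # 2. kmeans summary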

    Categorical variables
    ---------------------
    The interpretation of the results for categorical variables can be more
    challenging, and depends on the type of encoding used. Both ordinal and
    one-hot encodings can be hard to interpret.

    There is no agreement on the best strategy for handling categorical
    variables. A good discussion on this can be found here:
    https://github.com/slundberg/shap/issues/451

    The evaluator imposes no restriction on feature type, so user discretion
    in the interpretation of shap values for categorical variables is advised.

    Parameters
    ----------
    samples_ind : Optional[List[int]], optional
        List of row numbers representing the samples for which to extract
        individual shapley values. This must be a list of integer indices. The
        underlying SHAP library does not support non-integer indexing.
    background_samples : int
        Number of samples to be taken from the dataset in order to build the
        reference values. See the documentation about sampling above. Unused
        if background_kmeans is not False.
    background_kmeans : Union[bool, int], optional
        If True, use SHAP kmeans to create a data summary to serve as
        background data for the SHAP explainer, using 50 centroids by default.
        If an int is provided, it is used as the number of centroids. If
        False, random sampling takes place.
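
    Examples
    --------
    A minimal usage sketch, not a definitive recipe: it assumes that
    ``credo_model`` and ``credo_data`` are already-built Credo AI model and
    data artifacts, and that the evaluator runs through the standard ``Lens``
    workflow::

        from credoai.lens import Lens

        lens = Lens(model=credo_model, assessment_data=credo_data)
        # Summarize background data with 25 kmeans centroids and also extract
        # individual shap values for the first two samples.
        lens.add(ShapExplainer(background_kmeans=25, samples_ind=[0, 1]))
        lens.run()
        results = lens.get_results()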
71 """

    required_artifacts = ["assessment_data", "model"]

    def __init__(
        self,
        samples_ind: Optional[List[int]] = None,
        background_samples: int = 100,
        background_kmeans: Union[bool, int] = False,
    ):
        super().__init__()
        self.samples_ind = samples_ind
        self._validate_samples_ind()
        self.background_samples = background_samples
        self.background_kmeans = background_kmeans
        self.classes = [None]

    def _validate_arguments(self):
        check_requirements_existence(self)

    def _setup(self):
        self.X = self.assessment_data.X
        return self

    def evaluate(self):
        ## Overall stats
        self._setup_shap()
        self.results = [
            TableContainer(
                self._get_overall_shap_contributions(), **self.get_container_info()
            )
        ]

        ## Sample specific results
        if self.samples_ind:
            ind_res = self._get_mult_sample_shapley_values()
            self.results += [TableContainer(ind_res, **self.get_container_info())]
        return self

    def _setup_shap(self):
        """
        Setup the explainer given the model and the feature dataset
        """
        if self.background_kmeans:
            # type() rather than isinstance() is deliberate: isinstance would
            # treat True (bool is a subclass of int) as a centroid count of 1.
            if type(self.background_kmeans) is int:
                centroids_num = self.background_kmeans
            else:
                centroids_num = 50
            # shap.kmeans returns a DenseData summary; .data is the array of
            # weighted centroids used as background data.
            data_summary = kmeans(self.X, centroids_num).data
        else:
            data_summary = self.X.sample(self.background_samples)
        # Try to use the model-like object directly, which only works for
        # model types that shap supports; otherwise fall back to wrapping
        # the predict function.
        try:
            explainer = Explainer(self.model.model_like, data_summary)
        except Exception:
            explainer = Explainer(self.model.predict, data_summary)
        # Generate the shap values by calling the explainer on the full dataset
        self.shap_values = explainer(self.X)
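        # `explainer(X)` returns a shap.Explanation whose `.values` array has
        # shape (n_samples, n_features) for single-output models and
        # (n_samples, n_features, n_classes) for multiclass classifiers.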

        # Define values dataframes and classes depending on the shape of the
        # returned values, to account for multiclass classification.
        s_values = self.shap_values.values
        if len(s_values.shape) == 2:
            self.values_df = [pd.DataFrame(s_values)]
        elif len(s_values.shape) == 3:
            self.values_df = [
                pd.DataFrame(s_values[:, :, i]) for i in range(s_values.shape[2])
            ]
            self.classes = self.model.model_like.classes_
        else:
            raise RuntimeError(
                f"Shap values have an unsupported format. Detected shape {s_values.shape}"
            )
        return self

    def _get_overall_shap_contributions(self) -> pd.DataFrame:
        """
        Calculate overall SHAP contributions for a dataset.

        The SHAP package outputs Shapley values for each sample in a dataset.
        To summarize the contribution of each feature across the dataset, the
        per-sample contributions need to be aggregated.

        For each feature, this method provides the mean and the mean of the
        absolute value of the samples' Shapley values.
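
        For a feature with per-sample shap values x_1, ..., x_n, the reported
        statistics are mean = (1/n) * sum(x_i) and mean(|x|) = (1/n) * sum(|x_i|);
        the latter is the usual measure of overall feature importance, since
        positive and negative contributions do not cancel out.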

        Returns
        -------
        pd.DataFrame
            Summary of the Shapley values across the full dataset.
        """

        shap_summaries = [
            self._summarize_shap_values(frame) for frame in self.values_df
        ]
        if len(self.classes) > 1:
            # assign returns a new DataFrame, so rebuild the list rather than
            # discarding the result
            shap_summaries = [
                df.assign(class_label=label)
                for label, df in zip(self.classes, shap_summaries)
            ]
        # fmt: off
        shap_summary = (
            pd.concat(shap_summaries)
            .reset_index()
            .rename({"index": "feature_name"}, axis=1)
        )
        # fmt: on
        shap_summary.name = "Summary of Shap statistics"
        return shap_summary

    def _summarize_shap_values(self, shap_val: pd.DataFrame) -> pd.DataFrame:
        """
        Summarize shap values at a dataset level.

        Parameters
        ----------
        shap_val : pd.DataFrame
            Table containing shap values. If the model output is multiclass,
            the table corresponds to the values for a single class.

        Returns
        -------
        pd.DataFrame
            Summarized shap values.
        """
        shap_val.columns = self.shap_values.feature_names
        summaries = {"mean": np.mean, "mean(|x|)": lambda x: np.mean(np.abs(x))}
        results = map(lambda func: shap_val.apply(func), summaries.values())
        # fmt: off
        final = (
            pd.concat(results, axis=1)
            .set_axis(list(summaries.keys()), axis=1)
            .sort_values("mean(|x|)", ascending=False)
        )
        # fmt: on
        final.name = "Summary of Shap statistics"
        return final

    def _get_mult_sample_shapley_values(self) -> pd.DataFrame:
        """
        Return shapley values for multiple samples from the dataset.

        Returns
        -------
        pd.DataFrame
            Columns:
                values -> shap values
                ref_value -> reference value for the shap values
                    (generally the same across the dataset)
                sample_pos -> position of the sample in the dataset
        """
        all_sample_shaps = []
        for ind in self.samples_ind:
            sample_results = self._get_single_sample_values(self.shap_values[ind])
            sample_results = sample_results.assign(sample_pos=ind)
            all_sample_shaps.append(sample_results)

        res = pd.concat(all_sample_shaps)
        res.name = "Shap values for specific samples"
        return res

    def _validate_samples_ind(self, limit=5):
        """
        Enforce a limit on the maximum number of samples for which to extract
        individual shap values.

        A maximum number of samples is enforced in order to constrain the
        amount of information in transit to the Credo AI Platform, both for
        performance and security reasons.

        Parameters
        ----------
        limit : int, optional
            Max number of samples allowed, by default 5.

        Raises
        ------
        ValidationError
        """
        if self.samples_ind is not None:
            if len(self.samples_ind) > limit:
                message = f"The maximum number of individual samples_ind allowed is {limit}."
                raise ValidationError(message)

    def _get_single_sample_values(self, sample_shap: Explanation) -> pd.DataFrame:
        """
        Return shapley values for a specific sample in the dataset.

        Parameters
        ----------
        sample_shap : Explanation
            Explainer object output for a specific sample.

        Returns
        -------
        pd.DataFrame
            Columns: values, ref_value (and class_label for multiclass models).
            Contains shapley values for the sample, and the reference value.
            The model prediction for the sample is equal to: ref_value + sum(values)
        """

        class_values = []

        # Single-output models (e.g., regression or single-output binary
        # classification): one set of shap values and one reference value.
        if len(self.classes) == 1:
            return pd.DataFrame({"values": sample_shap.values}).assign(
                ref_value=sample_shap.base_values,
                column_names=self.shap_values.feature_names,
            )
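
        # Multiclass models: sample_shap.values has shape
        # (n_features, n_classes), so build one frame per class and stack them.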
        for class_idx, class_name in enumerate(self.classes):
            class_values.append(
                pd.DataFrame({"values": sample_shap.values[:, class_idx]}).assign(
                    class_label=class_name,
                    ref_value=sample_shap.base_values[class_idx],
                    column_names=self.shap_values.feature_names,
                )
            )
        return pd.concat(class_values)