Coverage for credoai/evaluators/shap_credoai.py: 84%
87 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1import enum
2from typing import Dict, List, Optional, Union
4import numpy as np
5import pandas as pd
6from connect.evidence import TableContainer
7from shap import Explainer, Explanation, kmeans
9from credoai.evaluators.evaluator import Evaluator
10from credoai.evaluators.utils.validation import check_requirements_existence
11from credoai.utils.common import ValidationError
class ShapExplainer(Evaluator):
    """
    This evaluator performs the calculation of shapley values for a dataset/model,
    leveraging the SHAP package.

    It supports 2 types of assessments:

    1. Overall statistics of the shap values across all samples: mean and mean(|x|)
    2. Individual shapley values for a list of samples

    Sampling
    --------
    In order to speed up computation time, at the stage in which the SHAP explainer is
    initialized, a down sampled version of the dataset is passed to the `Explainer`
    object as background data. This is only affecting the calculation of the reference
    value, the calculation of the shap values is still performed on the full dataset.

    Two strategies for down sampling are provided:

    1. Random sampling (the default strategy): the amount of samples can be specified
       by the user.
    2. Kmeans: summarizes a dataset with k mean centroids, weighted by the number of
       data points they each represent. The amount of centroids can also be specified
       by the user.

    There is no consensus on the optimal down sampling approach. For reference, see this
    conversation: https://github.com/slundberg/shap/issues/1018

    Categorical variables
    ---------------------
    The interpretation of the results for categorical variables can be more challenging, and
    dependent on the type of encoding utilized. Ordinal or one/hot encoding can be hard to
    interpret.

    There is no agreement as to what is the best strategy as far as categorical variables are
    concerned. A good discussion on this can be found here: https://github.com/slundberg/shap/issues/451

    No restriction on feature type is imposed by the evaluator, so user discretion in the
    interpretation of shap values for categorical variables is advised.

    Parameters
    ----------
    samples_ind : Optional[List[int]], optional
        List of row numbers representing the samples for which to extract individual
        shapley values. This must be a list of integer indices. The underlying SHAP
        library does not support non-integer indexing.
    background_samples : int
        Amount of samples to be taken from the dataset in order to build the reference values.
        See documentation about sampling above. Unused if background_kmeans is not False.
    background_kmeans : Union[bool, int], optional
        If True, use SHAP kmeans to create a data summary to serve as background data for the
        SHAP explainer using 50 centroids by default. If an int is provided,
        that will be used as the number of centroids. If False, random sampling will take place.
    """

    required_artifacts = ["assessment_data", "model"]

    def __init__(
        self,
        samples_ind: Optional[List[int]] = None,
        background_samples: int = 100,
        background_kmeans: Union[bool, int] = False,
    ):
        super().__init__()
        self.samples_ind = samples_ind
        self._validate_samples_ind()
        self.background_samples = background_samples
        self.background_kmeans = background_kmeans
        # Single unnamed class by default; overwritten in _setup_shap when the
        # shap values turn out to be multiclass (3-dimensional).
        self.classes = [None]

    def _validate_arguments(self):
        check_requirements_existence(self)

    def _setup(self):
        # Only the feature matrix needs local aliasing; self.model is already
        # set by the framework (the previous `self.model = self.model` was a no-op).
        self.X = self.assessment_data.X
        return self

    def evaluate(self):
        """
        Run the assessment: overall shap statistics, plus per-sample values
        when samples_ind was provided.
        """
        ## Overall stats
        self._setup_shap()
        self.results = [
            TableContainer(self._get_overall_shap_contributions(), **self.get_info())
        ]

        ## Sample specific results
        if self.samples_ind:
            ind_res = self._get_mult_sample_shapley_values()
            self.results += [TableContainer(ind_res, **self.get_info())]
        return self

    def _setup_shap(self):
        """
        Setup the explainer given the model and the feature dataset.

        Builds the background data (kmeans summary or random sample), creates the
        SHAP Explainer, computes shap values on the full dataset and reshapes them
        into one DataFrame per class.
        """
        if self.background_kmeans:
            # NOTE: `type(...) is int` (not isinstance) is intentional here:
            # bool is a subclass of int, and background_kmeans=True must select
            # the default of 50 centroids rather than be treated as the int 1.
            if type(self.background_kmeans) is int:
                centroids_num = self.background_kmeans
            else:
                centroids_num = 50
            data_summary = kmeans(self.X, centroids_num).data
        else:
            data_summary = self.X.sample(self.background_samples)
        # Try to use the model-like, which will only work if it is a model
        # that shap supports; otherwise fall back to the predict callable.
        try:
            explainer = Explainer(self.model.model_like, data_summary)
        except Exception:
            explainer = Explainer(self.model.predict, data_summary)
        # Generating the actual values calling the specific Shap function
        self.shap_values = explainer(self.X)

        # Define values dataframes and classes variables depending on
        # the shape of the returned values. This accounts for multi class
        # classification.
        s_values = self.shap_values.values
        if len(s_values.shape) == 2:
            # Single output: one (samples x features) frame.
            self.values_df = [pd.DataFrame(s_values)]
        elif len(s_values.shape) == 3:
            # Multiclass: one frame per class, sliced along the last axis.
            self.values_df = [
                pd.DataFrame(s_values[:, :, i]) for i in range(s_values.shape[2])
            ]
            self.classes = self.model.model_like.classes_
        else:
            raise RuntimeError(
                f"Shap values have unsupported format. Detected shape {s_values.shape}"
            )
        return self

    def _get_overall_shap_contributions(self) -> pd.DataFrame:
        """
        Calculate overall SHAP contributions for a dataset.

        The output of SHAP package provides Shapley values for each sample in a
        dataset. To summarize the contribution of each feature in a dataset, the
        samples contributions need to be aggregated.

        For each of the features, this method provides: mean and the
        mean of the absolute value of the samples Shapley values.

        Returns
        -------
        pd.DataFrame
            Summary of the Shapley values across the full dataset.
        """
        shap_summaries = [
            self._summarize_shap_values(frame) for frame in self.values_df
        ]
        if len(self.classes) > 1:
            # BUGFIX: DataFrame.assign returns a NEW frame; the previous code
            # discarded the result, so class_label was never actually added.
            shap_summaries = [
                df.assign(class_label=label)
                for label, df in zip(self.classes, shap_summaries)
            ]
        # fmt: off
        shap_summary = (
            pd.concat(shap_summaries)
            .reset_index()
            .rename({"index": "feature_name"}, axis=1)
        )
        # fmt: on
        shap_summary.name = "Summary of Shap statistics"
        return shap_summary

    def _summarize_shap_values(self, shap_val: pd.DataFrame) -> pd.DataFrame:
        """
        Summarize Shap values at a dataset level.

        Parameters
        ----------
        shap_val : pd.DataFrame
            Table containing shap values, if the model output is multiclass,
            the table corresponds to the values for a single class.

        Returns
        -------
        pd.DataFrame
            Summarized shap values, sorted by mean(|x|) descending.
        """
        shap_val.columns = self.shap_values.feature_names
        summaries = {"mean": np.mean, "mean(|x|)": lambda x: np.mean(np.abs(x))}
        results = map(lambda func: shap_val.apply(func), summaries.values())
        # list(...) rather than a dict_keys view: safer across pandas versions.
        # fmt: off
        final = (
            pd.concat(results, axis=1)
            .set_axis(list(summaries.keys()), axis=1)
            .sort_values("mean(|x|)", ascending=False)
        )
        # fmt: on
        final.name = "Summary of Shap statistics"
        return final

    def _get_mult_sample_shapley_values(self) -> pd.DataFrame:
        """
        Return shapley values for multiple samples from the dataset.

        Returns
        -------
        pd.DataFrame
            Columns:
                values -> shap values
                ref_value -> Reference value for the shap values
                    (generally the same across the dataset)
                sample_pos -> Position of the sample in the dataset
        """
        all_sample_shaps = []
        for ind in self.samples_ind:
            sample_results = self._get_single_sample_values(self.shap_values[ind])
            sample_results = sample_results.assign(sample_pos=ind)
            all_sample_shaps.append(sample_results)

        res = pd.concat(all_sample_shaps)
        res.name = "Shap values for specific samples"
        return res

    def _validate_samples_ind(self, limit=5):
        """
        Enforce limit on maximum amount of samples for which to extract
        individual shap values.

        A maximum number of samples is enforced, this is in order to constrain the
        amount of information in transit to Credo AI Platform, both for performance
        and security reasons.

        Parameters
        ----------
        limit : int, optional
            Max number of samples allowed, by default 5.

        Raises
        ------
        ValidationError
        """
        if self.samples_ind is not None:
            if len(self.samples_ind) > limit:
                # Interpolate the actual limit instead of hard-coding "5",
                # so the message stays correct if a different limit is passed.
                message = (
                    f"The maximum amount of individual samples_ind allowed is {limit}."
                )
                raise ValidationError(message)

    def _get_single_sample_values(self, sample_shap: Explanation) -> pd.DataFrame:
        """
        Returns shapley values for a specific sample in the dataset.

        Parameters
        ----------
        sample_shap : Explanation
            Explainer object output for a specific sample.

        Returns
        -------
        pd.DataFrame
            Columns: values, ref_value (and class_label for multiclass).
            Contains shapley values for the sample, and the reference value.
            The model prediction for the sample is equal to: ref_value + sum(values)
        """
        if len(self.classes) == 1:
            # Single-output model: values is a 1-D vector over features.
            return pd.DataFrame({"values": sample_shap.values}).assign(
                ref_value=sample_shap.base_values,
                column_names=self.shap_values.feature_names,
            )

        # Multiclass: one (features,) column of values per class label.
        class_values = []
        for label, cls in enumerate(self.classes):
            class_values.append(
                pd.DataFrame({"values": sample_shap.values[:, label]}).assign(
                    class_label=cls,
                    ref_value=sample_shap.base_values[label],
                    column_names=self.shap_values.feature_names,
                )
            )
        return pd.concat(class_values)