Coverage for credoai/artifacts/data/base_data.py: 89%
105 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 07:32 +0000
1"""Abstract class for the data artifacts used by `Lens`"""
2# Data is a lightweight wrapper that stores data
3import itertools
4from abc import ABC, abstractmethod
5from typing import Optional, Union
7import pandas as pd
9from credoai.utils.common import ValidationError
10from credoai.utils.model_utils import type_of_target
12from copy import deepcopy
class Data(ABC):
    """Class wrapper around data-to-be-assessed

    Data is passed to Lens for certain assessments.

    Data serves as an adapter between datasets
    and the evaluators in Lens.

    Parameters
    -------------
    type : str
        Type of the dataset
    name : str
        Label of the dataset
    X : to-be-defined by children
        Dataset
    y : to-be-defined by children
        Outcome
    sensitive_features : pd.Series, pd.DataFrame, optional
        Sensitive Features, which will be used for disaggregating performance
        metrics. This can be the columns you want to perform segmentation analysis on, or
        a feature related to fairness like 'race' or 'gender'
    sensitive_intersections : bool, list
        Whether to add intersections of sensitive features. If True, add all possible
        intersections. If list, only create intersections from specified sensitive features.
        If False, no intersections will be created. Defaults False
    """

    def __init__(
        self,
        type: str,
        name: str,
        X=None,
        y=None,
        sensitive_features=None,
        sensitive_intersections: Union[bool, list] = False,
    ):
        # NOTE(review): `type` shadows the builtin and is not stored here;
        # presumably subclasses consume it — confirm before renaming/removing.
        if isinstance(name, str):
            self.name = name
        else:
            # BUG FIX: the original message was "{Name} must be a string" —
            # the braces were literal because the string was never formatted.
            raise ValidationError("Name must be a string")
        self.X = X
        self.y = y
        self.sensitive_features = sensitive_features
        self._validate_inputs()
        self._process_inputs(sensitive_intersections)
        self._validate_processing()
        self._active_sensitive_feature: Optional[str] = None

    @property
    def active_sens_feat(self):
        """
        Defines which sensitive feature an evaluator will be working on.

        In combination with the property sensitive_feature this effectively creates
        a view of a specific artifact. Lazily defaults to the first sensitive
        feature column when it has not been set explicitly.
        """
        if self._active_sensitive_feature is None:
            self._active_sensitive_feature = self.sensitive_features.columns[0]
        return self._active_sensitive_feature

    @active_sens_feat.setter
    def active_sens_feat(self, value: str):
        """
        Sets the active_sens_feat value.

        Parameters
        ----------
        value : str
            Name of the sensitive feature column an evaluator has to operate on.
        """
        self._active_sensitive_feature = value

    @property
    def sensitive_feature(self):
        """
        Reveals the sensitive feature defined by active_sens_feat.

        This is generally called from within an evaluator, when it is working
        on a single sensitive feature.
        """
        return self.sensitive_features[self.active_sens_feat]

    @property
    def y_type(self):
        """Target type of ``y`` as reported by ``type_of_target``."""
        return type_of_target(self.y)

    @property
    def data(self):
        """Bundles X, y and sensitive features into a single dict."""
        data = {"X": self.X, "y": self.y, "sensitive_features": self.sensitive_features}
        return data

    def _process_inputs(self, sensitive_intersections):
        """Runs the processing hooks on whichever inputs were supplied."""
        if self.X is not None:
            self.X = self._process_X(self.X)
        if self.y is not None:
            self.y = self._process_y(self.y)
        if self.sensitive_features is not None:
            # deepcopy so processing never mutates the user's original object
            self.sensitive_features = self._process_sensitive(
                deepcopy(self.sensitive_features), sensitive_intersections
            )

    def _process_sensitive(self, sensitive_features, sensitive_intersections):
        """
        Formats sensitive features

        Parameters
        ----------
        sensitive_features :
            Sensitive features as provided by a user. Any format that can be constrained
            in a dataframe is accepted.
        sensitive_intersections : bool, list
            Indicates whether to create intersections among sensitive features.

        Returns
        -------
        pd.DataFrame
            Sensitive features with one column per feature, plus one extra
            column per requested intersection (labelled "feat1_feat2_...").
        """
        df = pd.DataFrame(sensitive_features)
        # add intersections if asked for
        features = df.columns
        if sensitive_intersections is False or len(features) == 1:
            return df
        elif sensitive_intersections is True:
            sensitive_intersections = features
        intersections = []
        for i in range(2, len(features) + 1):
            intersections += list(itertools.combinations(sensitive_intersections, i))
        for intersection in intersections:
            # NOTE(review): assumes the first column is string-typed (uses .str);
            # subsequent columns are cast explicitly — confirm upstream guarantees.
            tmp = df[intersection[0]]
            for col in intersection[1:]:
                tmp = tmp.str.cat(df[col].astype(str), sep="_")
            label = "_".join(intersection)
            df[label] = tmp
        return df

    def _process_X(self, X):
        """Hook for subclasses to transform X; identity by default."""
        return X

    def _process_y(self, y):
        """Hook for subclasses to transform y; identity by default."""
        return y

    def _validate_inputs(self):
        """Basic input validation"""
        if self.X is not None:
            self._validate_X()
        if self.y is not None:
            self._validate_y()
        if self.sensitive_features is not None:
            self._validate_sensitive()

    def _validate_sensitive(self):
        """Sensitive features validation"""
        # Validate the types
        if not isinstance(self.sensitive_features, (pd.Series, pd.DataFrame)):
            raise ValidationError(
                "Sensitive_feature type is '"
                + type(self.sensitive_features).__name__
                + "' but the required type is either pd.DataFrame or pd.Series"
            )
        if self.X is not None:
            if len(self.X) != len(self.sensitive_features):
                # BUG FIX: the original message reported len(self.y) as the
                # sensitive-features length (and raised TypeError when y was None).
                raise ValidationError(
                    "X and sensitive_features are not the same length. "
                    + f"X Length: {len(self.X)}, "
                    + f"sensitive_features Length: {len(self.sensitive_features)}"
                )
            if isinstance(self.X, (pd.Series, pd.DataFrame)) and not self.X.index.equals(
                self.sensitive_features.index
            ):
                raise ValidationError("X and sensitive features must have the same index")

        if isinstance(self.sensitive_features, pd.Series):
            # BUG FIX: the original checked hasattr(series, "name"), which is
            # always True for a pd.Series (name exists, possibly None), so the
            # check could never fire. Validate that the name is actually set.
            if self.sensitive_features.name is None:
                raise ValidationError("Feature Series should have a name attribute")

    @abstractmethod
    def _validate_X(self):
        """Subclass hook: validate the raw X input."""
        pass

    @abstractmethod
    def _validate_y(self):
        """Subclass hook: validate the raw y input."""
        pass

    def _validate_processing(self):
        """Validation of processed data"""
        if self.X is not None:
            self._validate_processed_X()
        if self.y is not None:
            self._validate_processed_y()
        if self.sensitive_features is not None:
            self._validate_processed_sensitive()

    def _validate_processed_X(self):
        """Subclass hook: validate X after processing; no-op by default."""
        pass

    def _validate_processed_y(self):
        """Subclass hook: validate y after processing; no-op by default."""
        pass

    def _validate_processed_sensitive(self):
        """Validation of processed sensitive features"""
        # BUG FIX: DataFrame.iteritems() was deprecated and removed in
        # pandas 2.0; items() is the supported equivalent.
        for col_name, col in self.sensitive_features.items():
            unique_values = col.unique()
            if len(unique_values) == 1:
                raise ValidationError(
                    f"Sensitive Feature column {col_name} must have more "
                    f"than one unique value. Only found one value: {unique_values[0]}"
                )