Coverage for credoai/artifacts/data/base_data.py: 89%

105 statements  

coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1"""Abstract class for the data artifacts used by `Lens`""" 

2# Data is a lightweight wrapper that stores data 

3import itertools 

4from abc import ABC, abstractmethod 

5from typing import Optional, Union 

6 

7import pandas as pd 

8 

9from credoai.utils.common import ValidationError 

10from credoai.utils.model_utils import type_of_target 

11 

12from copy import deepcopy 

13 

14 

15class Data(ABC): 

16 """Class wrapper around data-to-be-assessed 

17 

18 Data is passed to Lens for certain assessments. 

19 

20 Data serves as an adapter between datasets 

21 and the evaluators in Lens. 

22 

23 Parameters 

24 ------------- 

25 type : str 

26 Type of the dataset 

27 name : str 

28 Label of the dataset 

29 X : to-be-defined by children 

30 Dataset 

31 y : to-be-defined by children 

32 Outcome 

33 sensitive_features : pd.Series, pd.DataFrame, optional 

34 Sensitive Features, which will be used for disaggregating performance 

35 metrics. This can be the columns you want to perform segmentation analysis on, or 

36 a feature related to fairness like 'race' or 'gender' 

37 sensitive_intersections : bool, list 

38 Whether to add intersections of sensitive features. If True, add all possible 

39 intersections. If list, only create intersections from specified sensitive features. 

40 If False, no intersections will be created. Defaults False 

41 """ 

42 

    def __init__(
        self,
        type: str,
        name: str,
        X=None,
        y=None,
        sensitive_features=None,
        sensitive_intersections: Union[bool, list] = False,
    ):
        if isinstance(name, str):
            self.name = name
        else:
            raise ValidationError("name must be a string")
        self.X = X
        self.y = y
        self.sensitive_features = sensitive_features
        self._validate_inputs()
        self._process_inputs(sensitive_intersections)
        self._validate_processing()
        self._active_sensitive_feature: Optional[str] = None

    @property
    def active_sens_feat(self):
        """
        Defines which sensitive feature an evaluator will be working on.

        In combination with the sensitive_feature property, this effectively
        creates a view of a specific artifact.
        """
        if self._active_sensitive_feature is None:
            self._active_sensitive_feature = self.sensitive_features.columns[0]
        return self._active_sensitive_feature

    @active_sens_feat.setter
    def active_sens_feat(self, value: str):
        """
        Sets the active_sens_feat value.

        Parameters
        ----------
        value : str
            Name of the sensitive feature column the evaluator should operate on.
        """
        self._active_sensitive_feature = value

    @property
    def sensitive_feature(self):
        """
        Returns the sensitive feature selected by active_sens_feat.

        This is generally called from within an evaluator when it is working
        on a single sensitive feature.
        """
        return self.sensitive_features[self.active_sens_feat]
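
    # Illustrative note (not part of the class API): an evaluator that works on
    # one sensitive feature at a time first sets
    # `data.active_sens_feat = "gender"` (a hypothetical column name) and then
    # reads `data.sensitive_feature`, which returns just that column of
    # `sensitive_features`, i.e. a view of the artifact restricted to that feature.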

    @property
    def y_type(self):
        return type_of_target(self.y)

    @property
    def data(self):
        data = {"X": self.X, "y": self.y, "sensitive_features": self.sensitive_features}
        return data

    def _process_inputs(self, sensitive_intersections):
        if self.X is not None:
            self.X = self._process_X(self.X)
        if self.y is not None:
            self.y = self._process_y(self.y)
        if self.sensitive_features is not None:
            self.sensitive_features = self._process_sensitive(
                deepcopy(self.sensitive_features), sensitive_intersections
            )

    def _process_sensitive(self, sensitive_features, sensitive_intersections):
        """
        Formats sensitive features

        Parameters
        ----------
        sensitive_features :
            Sensitive features as provided by a user. Any format that can be
            coerced into a dataframe is accepted.
        sensitive_intersections : bool, list
            Indicates whether to create intersections among sensitive features.

        Returns
        -------
        pd.DataFrame
            Sensitive features as a dataframe, with one additional column per
            requested intersection of sensitive features.
        """
        df = pd.DataFrame(sensitive_features)
        # add intersections if asked for
        features = df.columns
        if sensitive_intersections is False or len(features) == 1:
            return df
        elif sensitive_intersections is True:
            sensitive_intersections = features
        intersections = []
        for i in range(2, len(features) + 1):
            intersections += list(itertools.combinations(sensitive_intersections, i))
        for intersection in intersections:
            tmp = df[intersection[0]]
            for col in intersection[1:]:
                tmp = tmp.str.cat(df[col].astype(str), sep="_")
            label = "_".join(intersection)
            df[label] = tmp
        return df
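
    # Illustrative note (hypothetical data, not part of the class API): with
    # sensitive_features = pd.DataFrame({"gender": ["f", "m"], "race": ["a", "b"]})
    # and sensitive_intersections=True, the returned dataframe keeps the two
    # original columns and adds an intersection column "gender_race" whose values
    # are "f_a" and "m_b", built with str.cat(..., sep="_").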

    def _process_X(self, X):
        return X

    def _process_y(self, y):
        return y

    def _validate_inputs(self):
        """Basic input validation"""
        if self.X is not None:
            self._validate_X()
        if self.y is not None:
            self._validate_y()
        if self.sensitive_features is not None:
            self._validate_sensitive()

    def _validate_sensitive(self):
        """Sensitive features validation"""
        # Validate the types
        if not isinstance(self.sensitive_features, (pd.Series, pd.DataFrame)):
            raise ValidationError(
                "sensitive_features type is '"
                + type(self.sensitive_features).__name__
                + "' but the required type is either pd.DataFrame or pd.Series"
            )
        if self.X is not None:
            if len(self.X) != len(self.sensitive_features):
                raise ValidationError(
                    "X and sensitive_features are not the same length. "
                    + f"X length: {len(self.X)}, "
                    + f"sensitive_features length: {len(self.sensitive_features)}"
                )
            if isinstance(
                self.X, (pd.Series, pd.DataFrame)
            ) and not self.X.index.equals(self.sensitive_features.index):
                raise ValidationError(
                    "X and sensitive_features must have the same index"
                )

        if isinstance(self.sensitive_features, pd.Series):
            if self.sensitive_features.name is None:
                raise ValidationError("Sensitive feature Series must have a name")

    @abstractmethod
    def _validate_X(self):
        pass

    @abstractmethod
    def _validate_y(self):
        pass

    def _validate_processing(self):
        """Validation of processed data"""
        if self.X is not None:
            self._validate_processed_X()
        if self.y is not None:
            self._validate_processed_y()
        if self.sensitive_features is not None:
            self._validate_processed_sensitive()

    def _validate_processed_X(self):
        pass

    def _validate_processed_y(self):
        pass

    def _validate_processed_sensitive(self):
        """Validation of processed sensitive features"""
        for col_name, col in self.sensitive_features.items():
            unique_values = col.unique()
            if len(unique_values) == 1:
                raise ValidationError(
                    f"Sensitive Feature column {col_name} must have more "
                    f"than one unique value. Only found one value: {unique_values[0]}"
                )
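

# The sketch below is illustrative only and is not part of the credoai API:
# `_DemoTabularData` is a hypothetical minimal subclass used to show how the
# abstract `Data` wrapper validates and processes its inputs, including the
# optional sensitive-feature intersections and the active_sens_feat "view".
if __name__ == "__main__":

    class _DemoTabularData(Data):
        """Hypothetical concrete Data subclass with no extra validation."""

        def _validate_X(self):
            pass

        def _validate_y(self):
            pass

    X = pd.DataFrame({"feature": [1, 2, 3, 4]})
    y = pd.Series([0, 1, 0, 1], name="target")
    sensitive = pd.DataFrame(
        {"gender": ["f", "m", "f", "m"], "race": ["a", "a", "b", "b"]}
    )

    demo = _DemoTabularData(
        type="demo",
        name="demo_dataset",
        X=X,
        y=y,
        sensitive_features=sensitive,
        sensitive_intersections=True,
    )

    # The processed sensitive features now include a "gender_race" column.
    print(demo.data["sensitive_features"].columns.tolist())
    # Switch the sensitive-feature view an evaluator would operate on.
    demo.active_sens_feat = "gender_race"
    print(demo.sensitive_feature.head())
    print(demo.y_type)  # likely "binary" for a 0/1 target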