Coverage for credoai/artifacts/data/base_data.py: 89%

116 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1"""Abstract class for the data artifacts used by `Lens`""" 

2# Data is a lightweight wrapper that stores data 

3import itertools 

4from copy import deepcopy 

5from typing import Optional, Union 

6 

7import pandas as pd 

8 

9from credoai.utils import global_logger 

10from credoai.utils.common import ValidationError, check_pandas 

11from credoai.utils.model_utils import type_of_target 

12 

13 

class Data:
    """Class wrapper around data-to-be-assessed

    Data is passed to Lens for certain assessments.

    Data serves as an adapter between datasets
    and the evaluators in Lens.

    Parameters
    -------------
    type : str
        Type of the dataset
    name : str
        Label of the dataset
    X : to-be-defined by children
        Dataset
    y : to-be-defined by children
        Outcome
    sensitive_features : pd.Series, pd.DataFrame, optional
        Sensitive Features, which will be used for disaggregating performance
        metrics. This can be the columns you want to perform segmentation analysis on, or
        a feature related to fairness like 'race' or 'gender'
    sensitive_intersections : bool, list
        Whether to add intersections of sensitive features. If True, add all possible
        intersections. If list, only create intersections from specified sensitive features.
        If False, no intersections will be created. Defaults False
    """

    def __init__(
        self,
        # NOTE(review): `type` is accepted but never stored in this class —
        # presumably consumed by subclasses; confirm before removing
        type: str,
        name: str,
        X=None,
        y=None,
        sensitive_features=None,
        sensitive_intersections: Union[bool, list] = False,
    ):
        if isinstance(name, str):
            self.name = name
        else:
            # BUG FIX: message previously contained the literal text
            # "{Name} must be a string" (a botched f-string)
            raise ValidationError("name must be a string")
        self.X = X
        self.y = y
        self.sensitive_features = sensitive_features
        # Validate raw inputs, normalize them, then re-validate the processed form
        self._validate_inputs()
        self._process_inputs(sensitive_intersections)
        self._validate_processing()
        self._active_sensitive_feature: Optional[str] = None

    @property
    def active_sens_feat(self):
        """
        Defines which sensitive feature an evaluator will be working on.

        In combination with the property sensitive_feature this effectively creates
        a view of a specific artifact. Defaults to the first sensitive feature column.
        """
        if self._active_sensitive_feature is None:
            self._active_sensitive_feature = self.sensitive_features.columns[0]
        return self._active_sensitive_feature

    @active_sens_feat.setter
    def active_sens_feat(self, value: str):
        """
        Sets the active_sens_feat value.

        Parameters
        ----------
        value : str
            Name of the sensitive feature column an evaluator has to operate on.
        """
        self._active_sensitive_feature = value

    @property
    def sensitive_feature(self):
        """
        Reveals the sensitive feature defined by active_sens_feat.

        This is generally called from within an evaluator, when it is working
        on a single sensitive feature.
        """
        return self.sensitive_features[self.active_sens_feat]

    @property
    def y_type(self):
        """Infers the target type (e.g. binary, continuous) of y."""
        return type_of_target(self.y)

    @property
    def data(self):
        """Returns X, y and sensitive features bundled in a dictionary."""
        data = {"X": self.X, "y": self.y, "sensitive_features": self.sensitive_features}
        return data

    def _process_inputs(self, sensitive_intersections):
        """Normalizes X, y and sensitive features in place, skipping absent ones."""
        if self.X is not None:
            self.X = self._process_X(self.X)
        if self.y is not None:
            self.y = self._process_y(self.y)
        if self.sensitive_features is not None:
            self.sensitive_features = self._process_sensitive(
                self.sensitive_features, sensitive_intersections
            )

    def _process_sensitive(self, sensitive_features, sensitive_intersections):
        """
        Formats sensitive features

        Parameters
        ----------
        sensitive_features :
            Sensitive features as provided by a user. Any format that can be constrained
            in a dataframe is accepted.
        sensitive_intersections : Bool
            Indicates whether to create intersections among sensitive features.

        Returns
        -------
        pd.DataFrame
            dataframe of processed sensitive features
        """
        df = deepcopy(pd.DataFrame(sensitive_features))
        # An unnamed Series/array ends up with the integer column label 0;
        # replace it with a placeholder name
        if len(df.columns) == 1 and isinstance(df.columns[0], int):
            df.columns = ["NA"]
        # add intersections if asked for
        features = df.columns
        if sensitive_intersections is False or len(features) == 1:
            return df
        elif sensitive_intersections is True:
            sensitive_intersections = features
        intersections = []
        for i in range(2, len(features) + 1):
            intersections += list(itertools.combinations(sensitive_intersections, i))
        for intersection in intersections:
            # astype(str) makes concatenation robust to a non-string first
            # column, matching the astype(str) already applied to the others
            tmp = df[intersection[0]].astype(str)
            for col in intersection[1:]:
                tmp = tmp.str.cat(df[col].astype(str), sep="_")
            label = "_".join(intersection)
            df[label] = tmp
        return df

    def _process_X(self, X):
        """Hook for subclasses to normalize X; identity by default."""
        return X

    def _process_y(self, y):
        """Hook for subclasses to normalize y; identity by default."""
        return y

    def _validate_inputs(self):
        """Basic input validation"""
        if self.X is not None:
            self._validate_X()
        if self.y is not None:
            self._validate_y()
        if self.sensitive_features is not None:
            self._validate_sensitive()

    def _validate_sensitive(self):
        """Sensitive features validation"""
        # Validate the types
        if not check_pandas(self.sensitive_features):
            raise ValidationError(
                "Sensitive_feature type is '"
                + type(self.sensitive_features).__name__
                + "' but the required type is either pd.DataFrame or pd.Series"
            )
        if self.X is not None:
            if len(self.X) != len(self.sensitive_features):
                raise ValidationError(
                    "X and sensitive_features are not the same length. "
                    # BUG FIX: previously reported len(self.y) here, which was
                    # the wrong value and raised TypeError when y was None
                    + f"X Length: {len(self.X)}, "
                    + f"sensitive_features Length: {len(self.sensitive_features)}"
                )
            if check_pandas(self.X) and not self.X.index.equals(
                self.sensitive_features.index
            ):
                raise ValidationError("X and sensitive features must have the same index")

        if isinstance(self.sensitive_features, pd.Series):
            # NOTE(review): every pd.Series has a `name` attribute (possibly
            # None), so this check can never fire. Unnamed Series are handled
            # by the "NA" rename in _process_sensitive; behavior kept as-is.
            if not hasattr(self.sensitive_features, "name"):
                raise ValidationError(
                    "Sensitive Feature Series should have a name attribute"
                )

        if not self.sensitive_features.index.is_unique:
            raise ValidationError("Sensitive Features index must be unique")

    def _validate_X(self):
        """Hook for subclasses; no generic X validation."""
        pass

    def _validate_y(self):
        """Hook for subclasses; no generic y validation."""
        pass

    def _validate_processing(self):
        """Validation of processed data"""
        if self.X is not None:
            self._validate_processed_X()
        if self.y is not None:
            self._validate_processed_y()
        if self.sensitive_features is not None:
            self._validate_processed_sensitive()

    def _validate_processed_X(self):
        """Hook for subclasses; no generic processed-X validation."""
        pass

    def _validate_processed_y(self):
        """Hook for subclasses; no generic processed-y validation."""
        pass

    def _validate_processed_sensitive(self):
        """Validation of processed sensitive features"""
        # .items() replaces .iteritems(), which was removed in pandas 2.0
        for col_name, col in self.sensitive_features.items():
            # validate unique
            unique_values = col.unique()
            if len(unique_values) == 1:
                raise ValidationError(
                    f"Sensitive Feature column {col_name} must have more "
                    f"than one unique value. Only found one value: {unique_values[0]}"
                )
            # validate number in each group
            for group, value in col.value_counts().items():
                if value < 10:
                    global_logger.warning(
                        f"Dataset Issue! Very few ({value}) records were found for {group} under sensitive feature {col_name}."
                    )
            # validate variance in y
            if self.y is not None:
                y = pd.DataFrame(self.y)
                for outcome, outcome_col in y.items():
                    for group, value in outcome_col.groupby(col).std().items():
                        if value == 0:
                            global_logger.warning(
                                "%s\n%s",
                                f"Dataset Issue! Zero variance in the outcome ({outcome}) detected for {group} under sensitive feature {col_name}.",
                                "\tDownstream evaluators may fail or not perform as expected.",
                            )