Coverage for credoai/artifacts/data/tabular_data.py: 86%

42 statements  

coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1"""Data artifact wrapping any data in table format""" 

2from copy import deepcopy 

3from typing import Union 

4 

5import numpy as np 

6import pandas as pd 

7 

8from credoai.utils.common import ValidationError, check_array_like, check_pandas 

9 

10from .base_data import Data 

11 

12 

class TabularData(Data):
    """Class wrapper around tabular data

    TabularData serves as an adapter between tabular datasets
    and the evaluators in Lens. TabularData processes X, y, and
    sensitive features into validated pandas objects with aligned indices.

    Parameters
    ----------
    name : str
        Label of the dataset
    X : array-like of shape (n_samples, n_features)
        Dataset. Must be able to be transformed into a pandas DataFrame
    y : array-like of shape (n_samples, n_outputs)
        Outcome
    sensitive_features : pd.Series, pd.DataFrame, optional
        Sensitive features, which will be used for disaggregating performance
        metrics. This can be the feature you want to perform segmentation
        analysis on, or a feature related to fairness like 'race' or 'gender'.
        Sensitive features *must* be categorical features.
    sensitive_intersections : bool, list
        Whether to add intersections of sensitive features. If True, add all
        possible intersections. If a list, only create intersections from the
        specified sensitive features. If False, no intersections will be
        created. Defaults to False.
    """

    def __init__(
        self,
        name: str,
        X=None,
        y=None,
        sensitive_features=None,
        sensitive_intersections: Union[bool, list] = False,
    ):
        super().__init__(
            "Tabular", name, X, y, sensitive_features, sensitive_intersections
        )

    def copy(self):
        """Returns a deepcopy of the instantiated class"""
        return deepcopy(self)

    def _process_X(self, X):
        """Standardize X data

        Ensures X is a DataFrame with string-named columns
        """
        temp = pd.DataFrame(X)
        # Column names are converted to strings to avoid mixed types
        temp.columns = temp.columns.astype("str")
        # If X was not a pandas object, give it the index of the sensitive features
        if not check_pandas(X) and self.sensitive_features is not None:
            temp.index = self.sensitive_features.index
        return temp
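
    # Illustrative note (not part of the original file): pd.DataFrame(np.ones((2, 2)))
    # gets integer column labels 0 and 1; .columns.astype("str") renames them
    # to "0" and "1", so every column name downstream is a string.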

    def _process_y(self, y):
        """Standardize y data

        If y is not already a pandas object, convert it to a Series
        (or a DataFrame for multi-output data), reusing X's index
        when X is available
        """
        if isinstance(y, (pd.DataFrame, pd.Series)):
            return y
        pd_type = pd.Series
        if isinstance(y, np.ndarray) and y.ndim == 2 and y.shape[1] > 1:
            pd_type = pd.DataFrame
        if self.X is not None:
            y = pd_type(y, index=self.X.index)
        else:
            y = pd_type(y)
        y.name = "target"
        return y
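
    # Illustrative note (not part of the original file): when self.X is set,
    # a 1-D array like np.array([0, 1, 1]) becomes a pd.Series named "target"
    # indexed by X.index, while a 2-D array such as np.ones((3, 2)) becomes
    # a pd.DataFrame sharing the same index.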

    def _validate_X(self):
        check_array_like(self.X)

    def _validate_y(self):
        """Validation of y inputs"""
        check_array_like(self.y)
        if self.X is not None and (len(self.X) != len(self.y)):
            raise ValidationError(
                "X and y are not the same length. "
                + f"X Length: {len(self.X)}, y Length: {len(self.y)}"
            )
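
    # Illustrative note (not part of the original file): a 3-row X paired
    # with a 2-row y raises ValidationError with the message
    # "X and y are not the same length. X Length: 3, y Length: 2".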

    def _validate_processed_X(self):
        """Validate processed X"""
        if len(self.X.columns) != len(set(self.X.columns)):
            raise ValidationError("X contains duplicate column names")
        if not self.X.index.is_unique:
            raise ValidationError("X's index must be unique")

    def _validate_processed_y(self):
        """Validate processed y"""
        if isinstance(self.X, (pd.Series, pd.DataFrame)) and not self.X.index.equals(
            self.y.index
        ):
            raise ValidationError("X and y must have the same index")
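
A minimal usage sketch of the class above (the import path is assumed from the package layout shown in the imports; the column names and values are invented for illustration):

    import pandas as pd
    from credoai.artifacts import TabularData  # assumed public import path

    X = pd.DataFrame({"age": [25, 32, 47], "income": [40000, 52000, 61000]})
    y = pd.Series([0, 1, 1], name="target")
    sensitive = pd.Series(["F", "M", "F"], name="gender")

    data = TabularData(
        name="credit_data",
        X=X,
        y=y,
        sensitive_features=sensitive,
    )

    # After construction, X is a DataFrame with string column names, y is a
    # pandas object sharing X's index, and mismatched lengths or duplicate
    # column names raise ValidationError.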