Coverage for credoai/artifacts/data/tabular_data.py: 85%

40 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1"""Data artifact wrapping any data in table format""" 

2from copy import deepcopy 

3from typing import Union 

4 

5import numpy as np 

6import pandas as pd 

7 

8from credoai.utils.common import ValidationError, check_array_like 

9 

10from .base_data import Data 

11 

12 

class TabularData(Data):
    """Class wrapper around tabular data

    TabularData serves as an adapter between tabular datasets
    and the evaluators in Lens. TabularData processes X into a pandas
    DataFrame with string column names and, when y is not already a pandas
    object, converts y into a pandas Series/DataFrame built on X's index.

    Parameters
    -------------
    name : str
        Label of the dataset
    X : array-like of shape (n_samples, n_features), optional
        Dataset. Must be able to be transformed into a pandas DataFrame
    y : array-like of shape (n_samples, n_outputs), optional
        Outcome
    sensitive_features : pd.Series, pd.DataFrame, optional
        Sensitive Features, which will be used for disaggregating performance
        metrics. This can be the feature you want to perform segmentation analysis on, or
        a feature related to fairness like 'race' or 'gender'. Sensitive Features *must*
        be categorical features.
    sensitive_intersections : bool, list
        Whether to add intersections of sensitive features. If True, add all possible
        intersections. If list, only create intersections from specified sensitive features.
        If False, no intersections will be created. Defaults False
    """

    def __init__(
        self,
        name: str,
        X=None,
        y=None,
        sensitive_features=None,
        sensitive_intersections: Union[bool, list] = False,
    ):
        # "Tabular" is the data-type label forwarded to the base Data class.
        super().__init__(
            "Tabular", name, X, y, sensitive_features, sensitive_intersections
        )

    def copy(self):
        """Returns a deepcopy of the instantiated class"""
        return deepcopy(self)

    def _process_X(self, X):
        """Standardize X data

        Ensures X is a dataframe with string-named columns
        """
        # Shallow copy (no data is duplicated) so that renaming the columns
        # below can never leak back into a DataFrame supplied by the caller;
        # NOTE(review): older pandas releases appear to share axes between
        # pd.DataFrame(df) and df -- confirm against the supported versions.
        temp = pd.DataFrame(X).copy(deep=False)
        # Column names are converted to strings, to avoid mixed types
        temp.columns = temp.columns.astype("str")
        return temp

    def _process_y(self, y):
        """Standardize y data

        If y is convertible, convert y to pandas object with X's index.
        Pandas inputs are returned untouched. A 2D numpy array with more than
        one column becomes a DataFrame; anything else becomes a Series.
        """
        if isinstance(y, (pd.DataFrame, pd.Series)):
            return y
        pd_type = pd.Series
        if isinstance(y, np.ndarray) and y.ndim == 2:
            if y.shape[1] > 1:
                pd_type = pd.DataFrame
            elif y.shape[1] == 1:
                # pd.Series only accepts 1-dimensional data, so a single
                # output column of shape (n_samples, 1) is flattened first.
                y = y.ravel()
        # index=None makes pandas fall back to its default RangeIndex.
        index = self.X.index if self.X is not None else None
        y = pd_type(y, index=index)
        # For a Series this names the data; for a DataFrame it is only a
        # plain attribute, kept for backward compatibility.
        y.name = "target"
        return y

    def _validate_X(self):
        """Validation of X inputs"""
        check_array_like(self.X)

    def _validate_y(self):
        """Validation of Y inputs

        Raises
        ------
        ValidationError
            If X is set and X and y differ in length.
        """
        check_array_like(self.y)
        if self.X is not None and (len(self.X) != len(self.y)):
            raise ValidationError(
                "X and y are not the same length. "
                + f"X Length: {len(self.X)}, y Length: {len(self.y)}"
            )

    def _validate_processed_X(self):
        """Validate processed X

        Raises
        ------
        ValidationError
            If X has duplicated column names or duplicated index labels.
        """
        if not self.X.columns.is_unique:
            raise ValidationError("X contains duplicate column names")
        if not self.X.index.is_unique:
            raise ValidationError("X's index cannot contain duplicates")

    def _validate_processed_y(self):
        """Validate processed Y

        Raises
        ------
        ValidationError
            If X is a pandas object whose index differs from y's index.
        """
        if isinstance(self.X, (pd.Series, pd.DataFrame)) and not self.X.index.equals(
            self.y.index
        ):
            raise ValidationError("X and y must have the same index")