Coverage for credoai/artifacts/data/tabular_data.py: 85% (40 statements)

1"""Data artifact wrapping any data in table format"""
2from copy import deepcopy
3from typing import Union
5import numpy as np
6import pandas as pd
8from credoai.utils.common import ValidationError, check_array_like
10from .base_data import Data
class TabularData(Data):
    """Class wrapper around tabular data

    TabularData serves as an adapter between tabular datasets
    and the evaluators in Lens. TabularData processes X, y, and
    sensitive features into standardized pandas objects.

    Parameters
    ----------
    name : str
        Label of the dataset
    X : array-like of shape (n_samples, n_features)
        Dataset. Must be able to be transformed into a pandas DataFrame
    y : array-like of shape (n_samples, n_outputs)
        Outcome
    sensitive_features : pd.Series, pd.DataFrame, optional
        Sensitive features, which will be used for disaggregating performance
        metrics. This can be the feature you want to perform segmentation
        analysis on, or a feature related to fairness like 'race' or 'gender'.
        Sensitive features *must* be categorical.
    sensitive_intersections : bool, list
        Whether to add intersections of sensitive features. If True, add all
        possible intersections. If a list, only create intersections from the
        specified sensitive features. If False, no intersections will be
        created. Defaults to False.

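    Examples
    --------
    A minimal usage sketch (the data values below are made up for
    illustration, and the import path assumes ``TabularData`` is
    re-exported from ``credoai.artifacts``)::

        import pandas as pd
        from credoai.artifacts import TabularData

        X = pd.DataFrame({"age": [22, 38, 53], "income": [24_000, 51_000, 83_000]})
        y = pd.Series([0, 1, 1], name="defaulted")
        data = TabularData(name="credit_demo", X=X, y=y)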
36 """

    def __init__(
        self,
        name: str,
        X=None,
        y=None,
        sensitive_features=None,
        sensitive_intersections: Union[bool, list] = False,
    ):
        super().__init__(
            "Tabular", name, X, y, sensitive_features, sensitive_intersections
        )

    def copy(self):
        """Returns a deepcopy of the instantiated class"""
        return deepcopy(self)

    def _process_X(self, X):
        """Standardize X data

        Ensures X is a DataFrame with string-named columns
        """
        temp = pd.DataFrame(X)
        # Column names are converted to strings to avoid mixed types
        temp.columns = temp.columns.astype("str")
        return temp

    def _process_y(self, y):
        """Standardize y data

        If y is convertible, convert y to a pandas object with X's index
        """
        if isinstance(y, (pd.DataFrame, pd.Series)):
            return y
        # 1-D targets become a Series; 2-D multi-output arrays become a DataFrame
        pd_type = pd.Series
        if isinstance(y, np.ndarray) and y.ndim == 2 and y.shape[1] > 1:
            pd_type = pd.DataFrame
        # Align y with X's index when X is available
        if self.X is not None:
            y = pd_type(y, index=self.X.index)
        else:
            y = pd_type(y)
        # For a Series this sets the name; for a DataFrame it only sets an attribute
        y.name = "target"
        return y

    def _validate_X(self):
        """Validation of X inputs"""
        check_array_like(self.X)

    def _validate_y(self):
        """Validation of y inputs"""
        check_array_like(self.y)
        if self.X is not None and (len(self.X) != len(self.y)):
            raise ValidationError(
                "X and y are not the same length. "
                + f"X Length: {len(self.X)}, y Length: {len(self.y)}"
            )

    def _validate_processed_X(self):
        """Validate processed X"""
        if len(self.X.columns) != len(set(self.X.columns)):
            raise ValidationError("X contains duplicate column names")
        if len(self.X.index) != len(set(self.X.index)):
            raise ValidationError("X's index cannot contain duplicates")

    def _validate_processed_y(self):
        """Validate processed y"""
        if isinstance(self.X, (pd.Series, pd.DataFrame)) and not self.X.index.equals(
            self.y.index
        ):
            raise ValidationError("X and y must have the same index")
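

# A small usage sketch (illustrative only, not part of the original module; it
# assumes the Data base class invokes the _process_*/_validate_* hooks during
# __init__, as the adapter design above implies):
if __name__ == "__main__":
    data = TabularData(
        name="demo",
        X=[[1, 2], [3, 4]],  # array-like input, coerced to a DataFrame
        y=[0, 1],  # 1-D target, coerced to a Series named "target"
    )
    print(data.X.columns.tolist())  # column labels are strings, e.g. ['0', '1']
    print(data.y.name)  # "target"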