Coverage for credoai/artifacts/data/tabular_data.py: 86%
42 statements
coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1"""Data artifact wrapping any data in table format"""
2from copy import deepcopy
3from typing import Union
5import numpy as np
6import pandas as pd
8from credoai.utils.common import ValidationError, check_array_like, check_pandas
10from .base_data import Data
class TabularData(Data):
    """Class wrapper around tabular data

    TabularData serves as an adapter between tabular datasets
    and the evaluators in Lens. TabularData processes X and y
    into pandas objects with consistent, aligned indices.

    Parameters
    ----------
    name : str
        Label of the dataset
    X : array-like of shape (n_samples, n_features)
        Dataset. Must be convertible to a pandas DataFrame
    y : array-like of shape (n_samples, n_outputs)
        Outcome
    sensitive_features : pd.Series, pd.DataFrame, optional
        Sensitive Features, which will be used for disaggregating performance
        metrics. This can be the feature you want to perform segmentation
        analysis on, or a feature related to fairness like 'race' or 'gender'.
        Sensitive Features *must* be categorical features.
    sensitive_intersections : bool, list
        Whether to add intersections of sensitive features. If True, add all
        possible intersections. If a list, only create intersections from the
        specified sensitive features. If False, no intersections will be
        created. Defaults to False
    """
    def __init__(
        self,
        name: str,
        X=None,
        y=None,
        sensitive_features=None,
        sensitive_intersections: Union[bool, list] = False,
    ):
        super().__init__(
            "Tabular", name, X, y, sensitive_features, sensitive_intersections
        )
    def copy(self):
        """Returns a deepcopy of the instantiated class"""
        return deepcopy(self)
    def _process_X(self, X):
        """Standardize X data

        Ensures X is a DataFrame with string-named columns
        """
        temp = pd.DataFrame(X)
        # Column names are converted to strings to avoid mixed types
        temp.columns = temp.columns.astype("str")
        # If X was not a pandas object, give it the index of the sensitive features
        if not check_pandas(X) and self.sensitive_features is not None:
            temp.index = self.sensitive_features.index
        return temp
    def _process_y(self, y):
        """Standardize y data

        If y is not already a pandas object, convert it to one that
        shares X's index
        """
        if isinstance(y, (pd.DataFrame, pd.Series)):
            return y
        # A 2D array with more than one column becomes a DataFrame;
        # everything else becomes a Series
        pd_type = pd.Series
        if isinstance(y, np.ndarray) and y.ndim == 2 and y.shape[1] > 1:
            pd_type = pd.DataFrame
        # Reuse X's index when available so X and y stay aligned
        if self.X is not None:
            y = pd_type(y, index=self.X.index)
        else:
            y = pd_type(y)
        y.name = "target"
        return y
    def _validate_X(self):
        """Validation of X inputs"""
        check_array_like(self.X)
    def _validate_y(self):
        """Validation of y inputs"""
        check_array_like(self.y)
        if self.X is not None and (len(self.X) != len(self.y)):
            raise ValidationError(
                "X and y are not the same length. "
                + f"X Length: {len(self.X)}, y Length: {len(self.y)}"
            )
    def _validate_processed_X(self):
        """Validate processed X"""
        if len(self.X.columns) != len(set(self.X.columns)):
            raise ValidationError("X contains duplicate column names")
        if not self.X.index.is_unique:
            raise ValidationError("X's index must be unique")
    def _validate_processed_y(self):
        """Validate processed y"""
        if isinstance(self.X, (pd.Series, pd.DataFrame)) and not self.X.index.equals(
            self.y.index
        ):
            raise ValidationError("X and y must have the same index")