Coverage for credoai/artifacts/data/comparison_data.py: 79%
63 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1"""Data artifact for pair-wise-comparison-based identity verification"""
2from copy import deepcopy
4import pandas as pd
6from credoai.utils.common import ValidationError
8from .base_data import Data
class ComparisonData(Data):
    """Class wrapper for pair-wise-comparison-based identity verification

    ComparisonData serves as an adapter between pair-wise-comparison-based identity verification
    and the identity verification evaluator in Lens.

    Parameters
    -------------
    name : str
        Label of the dataset
    pairs : pd.DataFrame of shape (n_pairs, 4)
        Dataframe where each row represents a data sample pair and associated subjects
        Type of data sample is decided by the ComparisonModel's `compare` function, which takes
        data sample pairs and returns their similarity scores. Examples are selfies, fingerprint scans,
        or voices of a person.
        Required columns:
            source-subject-id: unique identifier of the source subject
            source-subject-data-sample: data sample from the source subject
            target-subject-id: unique identifier of the target subject
            target-subject-data-sample: data sample from the target subject
    subjects_sensitive_features : pd.DataFrame of shape (n_subjects, n_sensitive_feature_names), optional
        Sensitive features of all subjects present in pairs dataframe
        If provided, disaggregated performance assessment is also performed.
        This can be the columns you want to perform segmentation analysis on, or
        a feature related to fairness like 'race' or 'gender'
        Required columns:
            subject-id: id of subjects. Must cover all the subjects included in `pairs` dataframe
            other columns with arbitrary names for sensitive features

    Raises
    ------
    ValidationError
        If either dataframe fails validation, or if `pairs` references subject
        ids that are absent from `subjects_sensitive_features`.

    Notes
    -----
    The id columns of `pairs` and every column of `subjects_sensitive_features`
    are converted to `str` on construction. The conversion is done on new
    dataframes, so the objects passed in by the caller are not modified.
    """

    # The exact set of columns a `pairs` dataframe must have.
    _REQUIRED_PAIR_COLUMNS = (
        "source-subject-id",
        "source-subject-data-sample",
        "target-subject-id",
        "target-subject-data-sample",
    )
    # Subset of the pair columns that hold subject identifiers.
    _SUBJECT_ID_COLUMNS = ("source-subject-id", "target-subject-id")

    def __init__(self, name: str, pairs=None, subjects_sensitive_features=None):
        super().__init__("ComparisonData", name)
        self.pairs = pairs
        self.subjects_sensitive_features = subjects_sensitive_features
        # Validate the raw inputs first so preprocessing can rely on the
        # expected columns being present.
        self._validate_pairs()
        self._validate_subjects_sensitive_features()
        self._preprocess_pairs()
        self._preprocess_subjects_sensitive_features()
        # The cross-check runs last: it compares ids after both sides have
        # been converted to strings.
        self._validate_pairs_subjects_sensitive_features_match()

    def copy(self):
        """Returns a deepcopy of the instantiated class"""
        return deepcopy(self)

    def _validate_pairs(self):
        """Validate the input `pairs` object

        Does nothing when `pairs` is None.

        Raises
        ------
        ValidationError
            If `pairs` is not a dataframe, lacks a required column, does not
            have exactly 4 columns, or contains NaN values.
        """
        if self.pairs is None:
            return

        if not isinstance(self.pairs, pd.DataFrame):
            raise ValidationError("pairs must be a pd.DataFrame")

        available_columns = self.pairs.columns
        for c in self._REQUIRED_PAIR_COLUMNS:
            if c not in available_columns:
                raise ValidationError(
                    f"pairs dataframe does not contain the required column '{c}'"
                )

        if len(available_columns) != 4:
            raise ValidationError(
                f"pairs dataframe has '{len(available_columns)}' columns. It must have 4."
            )

        if self.pairs.isnull().values.any():
            raise ValidationError(
                "pairs dataframe contains NaN values. It must not have any."
            )

    def _validate_subjects_sensitive_features(self):
        """Validate the input `subjects_sensitive_features` object

        Does nothing when `subjects_sensitive_features` is None.

        Raises
        ------
        ValidationError
            If the object is not a dataframe, lacks the 'subject-id' column,
            has no sensitive feature column, contains NaN values, or has a
            sensitive feature column with a single unique value.
        """
        if self.subjects_sensitive_features is None:
            return

        if not isinstance(self.subjects_sensitive_features, pd.DataFrame):
            raise ValidationError(
                "subjects_sensitive_features must be a pd.DataFrame"
            )

        available_columns = self.subjects_sensitive_features.columns
        if "subject-id" not in available_columns:
            raise ValidationError(
                "subjects_sensitive_features dataframe does not contain the required column 'subject-id'"
            )
        if len(available_columns) < 2:
            raise ValidationError(
                "subjects_sensitive_features dataframe includes 'subject-id' column only. It must include at least one sensitive feature column too."
            )

        if self.subjects_sensitive_features.isnull().values.any():
            raise ValidationError(
                "subjects_sensitive_features dataframe contains NaN values. It must not have any."
            )

        # Every column other than 'subject-id' is treated as a sensitive feature.
        sensitive_features_names = [c for c in available_columns if c != "subject-id"]
        for sf_name in sensitive_features_names:
            unique_values = self.subjects_sensitive_features[sf_name].unique()
            if len(unique_values) == 1:
                raise ValidationError(
                    f"Sensitive Feature column {sf_name} must have more "
                    f"than one unique value. Only found one value: {unique_values[0]}"
                )

    def _preprocess_pairs(self):
        """Preprocess the input `pairs` object

        Converts the subject id columns to `str`. Does nothing when `pairs`
        is None.
        """
        if self.pairs is None:
            # `pairs` is optional in the constructor; nothing to convert.
            return
        # `astype` returns a new dataframe, so the caller's object is left
        # untouched (same approach as for `subjects_sensitive_features`).
        self.pairs = self.pairs.astype({c: str for c in self._SUBJECT_ID_COLUMNS})

    def _preprocess_subjects_sensitive_features(self):
        """Preprocess the input `subjects_sensitive_features` object

        Converts every column to `str`. Does nothing when the object is None.
        """
        if self.subjects_sensitive_features is not None:
            self.subjects_sensitive_features = self.subjects_sensitive_features.astype(
                str
            )

    def _validate_pairs_subjects_sensitive_features_match(self):
        """Check that every subject in `pairs` has sensitive features

        Skipped when either `pairs` or `subjects_sensitive_features` is None,
        since there is nothing to cross-check.

        Raises
        ------
        ValidationError
            If any source or target subject id in `pairs` is missing from the
            'subject-id' column of `subjects_sensitive_features`.
        """
        if self.subjects_sensitive_features is None or self.pairs is None:
            return

        subjects_in_pairs = set(self.pairs["source-subject-id"]).union(
            self.pairs["target-subject-id"]
        )
        subjects_in_subjects_sensitive_features = set(
            self.subjects_sensitive_features["subject-id"]
        )
        missing_ids = subjects_in_pairs - subjects_in_subjects_sensitive_features
        if missing_ids:
            # Ids are strings at this point, so sorting is safe and keeps the
            # message deterministic.
            raise ValidationError(
                "Some subject-id s that exist in the input `pairs` object do not exist in the input `subjects_sensitive_features` object. "
                f"These include {sorted(missing_ids)}."
            )

    def _validate_X(self):
        # Intentionally a no-op in this class.
        # NOTE(review): presumably overrides a hook on `Data` — confirm.
        pass

    def _validate_y(self):
        # Intentionally a no-op in this class.
        # NOTE(review): presumably overrides a hook on `Data` — confirm.
        pass