Coverage for credoai/artifacts/data/comparison_data.py: 79%

63 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1"""Data artifact for pair-wise-comparison-based identity verification""" 

2from copy import deepcopy 

3 

4import pandas as pd 

5 

6from credoai.utils.common import ValidationError 

7 

8from .base_data import Data 

9 

10 

class ComparisonData(Data):
    """Class wrapper for pair-wise-comparison-based identity verification

    ComparisonData serves as an adapter between pair-wise-comparison-based identity verification
    and the identity verification evaluator in Lens.

    Parameters
    -------------
    name : str
        Label of the dataset
    pairs : pd.DataFrame of shape (n_pairs, 4), optional
        Dataframe where each row represents a data sample pair and associated subjects
        Type of data sample is decided by the ComparisonModel's `compare` function, which takes
        data sample pairs and returns their similarity scores. Examples are selfies, fingerprint scans,
        or voices of a person.
        Required columns:
            source-subject-id: unique identifier of the source subject
            source-subject-data-sample: data sample from the source subject
            target-subject-id: unique identifier of the target subject
            target-subject-data-sample: data sample from the target subject
    subjects_sensitive_features : pd.DataFrame of shape (n_subjects, n_sensitive_feature_names), optional
        Sensitive features of all subjects present in pairs dataframe
        If provided, disaggregated performance assessment is also performed.
        This can be the columns you want to perform segmentation analysis on, or
        a feature related to fairness like 'race' or 'gender'
        Required columns:
            subject-id: id of subjects. Must cover all the subjects included in `pairs` dataframe
            other columns with arbitrary names for sensitive features

    Raises
    ------
    ValidationError
        If either dataframe fails the structural checks performed by the
        `_validate_*` methods, or if `pairs` references a subject id that is
        absent from `subjects_sensitive_features`.
    """

    def __init__(self, name: str, pairs=None, subjects_sensitive_features=None):
        super().__init__("ComparisonData", name)
        self.pairs = pairs
        self.subjects_sensitive_features = subjects_sensitive_features
        # Raw inputs are validated first (so NaN checks see real missing
        # values rather than the string "nan"), then ids are normalized to
        # strings, and only then are the two frames cross-checked so that
        # ids compare with a common type.
        self._validate_pairs()
        self._validate_subjects_sensitive_features()
        self._preprocess_pairs()
        self._preprocess_subjects_sensitive_features()
        self._validate_pairs_subjects_sensitive_features_match()

    def copy(self):
        """Returns a deepcopy of the instantiated class"""
        return deepcopy(self)

    def _validate_pairs(self):
        """Validate the input `pairs` object

        Does nothing when `pairs` is None.

        Raises
        ------
        ValidationError
            If `pairs` is not a DataFrame, lacks a required column, does not
            have exactly four columns, or contains NaN values.
        """
        if self.pairs is None:
            return

        # Basic validation for pairs
        if not isinstance(self.pairs, pd.DataFrame):
            raise ValidationError("pairs must be a pd.DataFrame")

        required_columns = (
            "source-subject-id",
            "source-subject-data-sample",
            "target-subject-id",
            "target-subject-data-sample",
        )
        available_columns = self.pairs.columns
        for column in required_columns:
            if column not in available_columns:
                raise ValidationError(
                    f"pairs dataframe does not contain the required column '{column}'"
                )

        # All required columns are present at this point, so any other
        # count means extra (or duplicated) columns.
        if len(available_columns) != 4:
            raise ValidationError(
                f"pairs dataframe has '{len(available_columns)}' columns. It must have 4."
            )

        if self.pairs.isnull().values.any():
            raise ValidationError(
                "pairs dataframe contains NaN values. It must not have any."
            )

    def _validate_subjects_sensitive_features(self):
        """Validate the input `subjects_sensitive_features` object

        Does nothing when `subjects_sensitive_features` is None.

        Raises
        ------
        ValidationError
            If the object is not a DataFrame, lacks the 'subject-id' column,
            has no sensitive feature column, contains NaN values, or has a
            sensitive feature column with exactly one unique value.
        """
        if self.subjects_sensitive_features is None:
            return

        # Basic validation for subjects_sensitive_features
        if not isinstance(self.subjects_sensitive_features, pd.DataFrame):
            raise ValidationError("subjects_sensitive_features must be a pd.DataFrame")

        available_columns = self.subjects_sensitive_features.columns
        if "subject-id" not in available_columns:
            raise ValidationError(
                "subjects_sensitive_features dataframe does not contain the required column 'subject-id'"
            )
        if len(available_columns) < 2:
            raise ValidationError(
                "subjects_sensitive_features dataframe includes 'subject-id' column only. It must include at least one sensitive feature column too."
            )

        if self.subjects_sensitive_features.isnull().values.any():
            raise ValidationError(
                "subjects_sensitive_features dataframe contains NaN values. It must not have any."
            )

        # Every column other than 'subject-id' is treated as a sensitive feature.
        sensitive_features_names = list(available_columns)
        sensitive_features_names.remove("subject-id")
        for sf_name in sensitive_features_names:
            unique_values = self.subjects_sensitive_features[sf_name].unique()
            if len(unique_values) == 1:
                raise ValidationError(
                    f"Sensitive Feature column {sf_name} must have more "
                    f"than one unique value. Only found one value: {unique_values[0]}"
                )

    def _preprocess_pairs(self):
        """Preprocess the input `pairs` object

        Casts the two subject id columns to `str`. Does nothing when `pairs`
        is None. The result is bound to a new DataFrame so the frame passed
        in by the caller is not modified.
        """
        if self.pairs is None:
            return
        self.pairs = self.pairs.astype(
            {"source-subject-id": str, "target-subject-id": str}
        )

    def _preprocess_subjects_sensitive_features(self):
        """Preprocess the input `subjects_sensitive_features` object

        Casts every column to `str`. Does nothing when
        `subjects_sensitive_features` is None.
        """
        if self.subjects_sensitive_features is not None:
            self.subjects_sensitive_features = self.subjects_sensitive_features.astype(
                str
            )

    def _validate_pairs_subjects_sensitive_features_match(self):
        """Check that every subject id in `pairs` has sensitive features

        Skipped when either `pairs` or `subjects_sensitive_features` is None,
        since there is nothing to cross-check.

        Raises
        ------
        ValidationError
            If a source or target subject id in `pairs` does not appear in the
            'subject-id' column of `subjects_sensitive_features`.
        """
        if self.pairs is None or self.subjects_sensitive_features is None:
            return

        id_columns = self.pairs[["source-subject-id", "target-subject-id"]]
        subjects_in_pairs = set(id_columns.values.ravel("K"))
        subjects_in_subjects_sensitive_features = set(
            self.subjects_sensitive_features["subject-id"].unique()
        )
        missing_ids = subjects_in_pairs - subjects_in_subjects_sensitive_features
        if missing_ids:
            raise ValidationError(
                "Some subject-ids that exist in the input `pairs` object do not exist "
                "in the input `subjects_sensitive_features` object. "
                f"These include {missing_ids}."
            )

    def _validate_X(self):
        # Intentionally a no-op for this artifact.
        # NOTE(review): presumably overrides a hook on the `Data` base class — confirm.
        pass

    def _validate_y(self):
        # Intentionally a no-op for this artifact.
        # NOTE(review): presumably overrides a hook on the `Data` base class — confirm.
        pass