Coverage for credoai/artifacts/model/base_model.py: 95%

42 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1"""Abstract class for model artifacts used by `Lens`""" 

2from abc import ABC 

3from typing import List, Optional 

4 

5from credoai.utils import ValidationError 

6from credoai.utils.model_utils import get_model_info 

7 

8 

class Model(ABC):
    """Base class for all models in Lens.

    Wraps a user-supplied ``model_like`` object. It validates that the object
    exposes the methods the model type requires, then attaches the available
    methods directly to this wrapper (e.g. ``model.predict(X)``).

    Parameters
    ----------
    type : str
        Type of the model.
    possible_functions : List[str]
        Names of methods that may be used by the model. Each one that exists
        on ``model_like`` as a callable is attached to this instance.
    necessary_functions : List[str]
        Names of methods that ``model_like`` must provide, as callables, for
        this model type.
    name : str
        Class name.
    model_like : model_like
        A model or pipeline.
    tags : dict, optional
        Additional metadata to add to the model,
        e.g. ``{'model_type': 'binary_classification'}``.

    Raises
    ------
    ValidationError
        If ``tags`` is not a dictionary, or if a necessary function is missing
        from ``model_like`` or is not callable.
    """

    def __init__(
        self,
        type: str,
        possible_functions: List[str],
        necessary_functions: List[str],
        name: str,
        model_like,
        tags: Optional[dict] = None,
    ):
        self.type = type
        self.name = name
        self.model_like = model_like
        # Only `None` is replaced by the default. Other falsy values (e.g. []
        # or "") must reach the `tags` setter so they are rejected, rather than
        # being silently swapped for an empty dict.
        self.tags = {} if tags is None else tags
        self._process_model(model_like, necessary_functions, possible_functions)
        self.__post_init__()

    def __post_init__(self) -> None:
        """Optional custom functionality to call after Base Model init.

        The base implementation does nothing; subclasses may override it.
        """
        pass

    @property
    def tags(self):
        """Metadata dictionary attached to the model (may be None if set so)."""
        return self._tags

    @tags.setter
    def tags(self, value):
        # None is accepted here; any other non-dict value is rejected.
        if not isinstance(value, dict) and value is not None:
            raise ValidationError("Tags must be of type dictionary")
        self._tags = value

    def _process_model(self, model_like, necessary_functions, possible_functions) -> "Model":
        """Inspect, validate and wrap ``model_like``.

        Stores ``get_model_info(model_like)`` as ``self.model_info``. It then
        runs the framework check and the necessary-callable check, and
        attaches every available possible function to this instance.

        Returns
        -------
        Model
            ``self``.
        """
        self.model_info = get_model_info(model_like)
        self._validate_framework()
        self._validate_callables(necessary_functions)
        self._build(possible_functions)
        return self

    def _build(self, function_names: List[str]) -> None:
        """
        Makes the available methods of ``model_like`` accessible on this class.

        Parameters
        ----------
        function_names : List[str]
            List of possible methods to be imported from model_like. Names that
            are missing, or that are not callable, are skipped.
        """
        for key in function_names:
            self._add_functionality(key)

    def _validate_framework(self) -> None:
        """
        Optional check to determine whether model is from supported framework.
        WARNS if unsupported framework (does not RAISE).

        The base implementation is a no-op; subclasses may override it.
        """
        pass

    def _validate_callables(self, function_names: List[str]) -> None:
        """
        Checks that the necessary methods are available in model_like.

        Parameters
        ----------
        function_names : List[str]
            List of necessary functions.

        Raises
        ------
        ValidationError
            If a necessary method is missing from model_like or is not callable.
        """
        for key in function_names:
            # `callable` rather than truthiness: a non-callable attribute must
            # not satisfy the check, and a callable with a falsy __bool__ must.
            if not callable(getattr(self.model_like, key, None)):
                raise ValidationError(f"Model-like must have a {key} function")

    def _add_functionality(self, key: str) -> None:
        """Attach ``model_like.<key>`` to this instance if it exists and is callable.

        Parameters
        ----------
        key : str
            Attribute name to look up on ``model_like``.
        """
        func = getattr(self.model_like, key, None)
        # Same callable test as `_validate_callables`: non-callable attributes
        # are not treated as functionality, and falsy callables are not dropped.
        if callable(func):
            # Stored directly in the instance __dict__, bypassing __setattr__.
            self.__dict__[key] = func