Coverage for credoai/artifacts/model/base_model.py: 94%

36 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 07:32 +0000

1"""Abstract class for model artifacts used by `Lens`""" 

2from abc import ABC 

3from typing import List, Optional 

4 

5from credoai.utils import ValidationError 

6from credoai.utils.model_utils import get_model_info 

7 

8 

class Model(ABC):
    """Base class for all models in Lens.

    Wraps a model-like object: checks that it provides the methods required
    for the model type and exposes the supported methods directly on the
    instance.

    Parameters
    ----------
    type : str
        Type of the model
    possible_functions: List[str]
        List of possible methods that can be used by a model
    necessary_functions: List[str]
        List of necessary methods for the model type
    name: str
        Class name.
    model_like : model_like
        A model or pipeline.
    tags : Optional[dict]
        Tags attached to the model, by default None. A falsy value
        (None or an empty dict) is stored as an empty dictionary.

    Raises
    ------
    ValidationError
        If `tags` is neither a dictionary nor None, or if a necessary
        method is missing from `model_like`.
    """

    def __init__(
        self,
        type: str,
        possible_functions: List[str],
        necessary_functions: List[str],
        name: str,
        model_like,
        tags: Optional[dict] = None,
    ):
        self.type = type
        self.name = name
        self.model_like = model_like
        # Goes through the `tags` property setter, which validates the type.
        self.tags = tags or {}
        self.model_info = get_model_info(model_like)
        # Validate before building so a missing necessary method fails
        # before any functionality is copied onto the instance.
        self._validate(necessary_functions)
        self._build(possible_functions)
        self._update_functionality()

    @property
    def tags(self):
        """Tags attached to the model."""
        return self._tags

    @tags.setter
    def tags(self, value):
        """
        Sets the tags after checking their type

        Raises
        ------
        ValidationError
            If value is neither a dictionary nor None
        """
        if not isinstance(value, dict) and value is not None:
            raise ValidationError("Tags must be of type dictionary")
        self._tags = value

    def _build(self, function_names: List[str]):
        """
        Makes the necessary methods available in the class

        Parameters
        ----------
        function_names : List[str]
            List of possible methods to be imported from model_like
        """
        for key in function_names:
            self._add_functionality(key)

    def _validate(self, function_names: List[str]):
        """
        Checks that the necessary methods are available in model_like

        Parameters
        ----------
        function_names : List[str]
            List of necessary functions

        Raises
        ------
        ValidationError
            If a necessary method is missing from model_like
        """
        for key in function_names:
            # Compare against None rather than relying on truthiness: a
            # callable object that defines __len__/__bool__ may be falsy
            # while still being a perfectly usable method.
            if getattr(self.model_like, key, None) is None:
                raise ValidationError(f"Model-like must have a {key} function")

    def _add_functionality(self, key: str):
        """Adds functionality from model_like, if it exists"""
        func = getattr(self.model_like, key, None)
        # Same presence test as `_validate`: only a missing (or None)
        # attribute is skipped, falsy-but-present callables are kept.
        if func is not None:
            # Stored in the instance __dict__ so the attribute is reachable
            # as `self.<key>`.
            self.__dict__[key] = func

    def _update_functionality(self):
        """Optional framework specific functionality update"""
        pass