Coverage for credoai/artifacts/model/base_model.py: 95%
42 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1"""Abstract class for model artifacts used by `Lens`"""
2from abc import ABC
3from typing import List, Optional
5from credoai.utils import ValidationError
6from credoai.utils.model_utils import get_model_info
class Model(ABC):
    """Base class for all models in Lens.

    On construction the wrapped ``model_like`` is inspected:

    * every name in ``necessary_functions`` must be a callable attribute of
      ``model_like``, otherwise ``ValidationError`` is raised;
    * every name in ``possible_functions`` that ``model_like`` exposes is
      bound onto this instance under the same name (e.g. ``self.predict``).

    Parameters
    ----------
    type : str
        Type of the model.
    possible_functions : List[str]
        Names of methods that are copied from ``model_like`` onto this
        instance when ``model_like`` has them.
    necessary_functions : List[str]
        Names of methods ``model_like`` must provide.
    name : str
        Class name.
    model_like : model_like
        A model or pipeline.
    tags : dict, optional
        Additional metadata to add to model,
        e.g. ``{'model_type': 'binary_classification'}``. ``None`` (the
        default) is stored as an empty dict.

    Raises
    ------
    ValidationError
        If ``tags`` is not a dict/None, or a necessary method is missing.
    """

    def __init__(
        self,
        type: str,
        possible_functions: List[str],
        necessary_functions: List[str],
        name: str,
        model_like,
        tags: Optional[dict] = None,
    ):
        # `type` shadows the builtin, but it is part of the public signature.
        self.type = type
        self.name = name
        self.model_like = model_like
        # Assign through the validating setter without pre-coercing falsy
        # values: `None` becomes `{}` there, while non-dict values such as
        # `[]` are rejected instead of silently turning into `{}`.
        self.tags = tags
        self._process_model(model_like, necessary_functions, possible_functions)
        self.__post_init__()

    def __post_init__(self):
        """Optional custom functionality to call after Base Model init"""
        pass

    @property
    def tags(self):
        """Metadata dictionary attached to this model (never ``None``)."""
        return self._tags

    @tags.setter
    def tags(self, value):
        if value is None:
            # Keep the attribute a dict so callers can always use dict methods.
            value = {}
        elif not isinstance(value, dict):
            raise ValidationError("Tags must be of type dictionary")
        self._tags = value

    def _process_model(self, model_like, necessary_functions, possible_functions):
        """Inspect, validate and wrap ``model_like``; returns ``self``.

        Steps run in order: collect model info, framework check, necessary
        callables check, then binding of the available possible functions.
        """
        self.model_info = get_model_info(model_like)
        self._validate_framework()
        self._validate_callables(necessary_functions)
        self._build(possible_functions)
        return self

    def _build(self, function_names: List[str]):
        """
        Makes the necessary methods available in the class

        Parameters
        ----------
        function_names : List[str]
            List of possible methods to be imported from model_like
        """
        for key in function_names:
            self._add_functionality(key)

    def _validate_framework(self):
        """
        Optional check to determine whether model is from supported framework.
        WARNS if unsupported framework (does not RAISE).
        """
        pass

    def _validate_callables(self, function_names: List[str]):
        """
        Checks that the necessary methods are available in model_like

        Parameters
        ----------
        function_names : List[str]
            List of necessary functions

        Raises
        ------
        ValidationError
            If a necessary method is missing from model_like or is not callable
        """
        for key in function_names:
            # Check callability rather than truthiness: a callable object may
            # define a falsy __bool__, and a truthy non-callable attribute is
            # not a usable method.
            if not callable(getattr(self.model_like, key, None)):
                raise ValidationError(f"Model-like must have a {key} function")

    def _add_functionality(self, key: str):
        """Adds functionality from model_like, if it exists"""
        func = getattr(self.model_like, key, None)
        if func is not None:
            # Stored in the instance dict so it is reachable as `self.<key>`;
            # data descriptors on the class (e.g. `tags`) still take precedence.
            self.__dict__[key] = func