Coverage for credoai/prism/prism.py: 89%
36 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-13 21:56 +0000
1from typing import List
3from credoai.lens import Lens
4from credoai.prism.task import Task
5from credoai.utils import ValidationError, global_logger
6from credoai.utils.common import NotRunError
class Prism:
    """
    **Experimental**

    Orchestrates the run of complex operations (Tasks) on
    an arbitrary number of Lens objects.

    Parameters
    ----------
    lenses : List[Lens]
        A list of Lens objects. The only requirement is for the Lens objects to
        be instantiated with their necessary artifacts. One or multiple Lens
        objects can be provided; each Task will validate that the number of
        objects provided is suitable for its requirement. A single Lens
        instance is also accepted and is wrapped into a one-element list.
    task : Task
        A task instance, instantiated with all the required parameters.

    Raises
    ------
    ValidationError
        If any element of ``lenses`` is not a Lens instance, or if ``task`` is
        not a Task instance.
    """

    def __init__(self, lenses: List[Lens], task: Task):
        # Convenience: accept a bare Lens. Everything below iterates over
        # self.lenses, so normalize it to a list first.
        if isinstance(lenses, Lens):
            lenses = [lenses]
        self.lenses = lenses
        self.task = task
        self._validate()
        # Becomes True once _run() has handled every Lens pipeline.
        self.run_flag = False
        # NOTE(review): not read or written anywhere else in this class.
        self.compare_results: List = []
        # One entry appended per execute() call (the task's get_results()).
        self.results: List = []

    def _validate(self):
        """
        Validate Prism parameters.

        Raises
        ------
        ValidationError
            If any element of ``self.lenses`` is not a Lens instance, or if
            ``self.task`` is not a Task instance.
        """
        for step in self.lenses:
            if not isinstance(step, Lens):
                raise ValidationError("Step must be a Lens instance")
        if not isinstance(self.task, Task):
            raise ValidationError(
                "The parameter task should be an instance of credoai.prism.Task"
            )

    @staticmethod
    def _lens_label(step) -> str:
        """Return a printable label for a Lens, tolerating a missing model."""
        # A Lens may have no model attached (or a model without a name);
        # fall back to the Lens class name instead of raising AttributeError.
        model = getattr(step, "model", None)
        return getattr(model, "name", None) or type(step).__name__

    def _run(self):
        """
        Run every Lens pipeline that has not been run yet.

        A Lens is considered already run when ``get_results()`` does not raise
        ``NotRunError``; otherwise its ``run()`` method is called. Sets
        ``run_flag`` to True once all Lens objects have been handled.
        """
        for step in self.lenses:
            try:
                step.get_results()
            except NotRunError:
                # Log first so the message precedes anything the run emits
                # (or any failure it raises).
                global_logger.info("Running step")
                step.run()
            else:
                global_logger.info(f"{self._lens_label(step)} pipeline already run")
        self.run_flag = True

    def execute(self):
        """
        Execute the task on the Lens objects.

        Runs any Lens pipelines that have not been run yet (skipped once a
        previous ``_run`` has completed, as tracked by ``run_flag``), then
        passes the Lens objects to the task, runs it, and appends its results
        to ``self.results``. Every call appends a new entry.
        """
        if not self.run_flag:
            self._run()
        self.results.append(self.task(pipelines=self.lenses).run().get_results())

    def get_results(self):
        """
        Return prism results.

        Returns
        -------
        list
            One item per ``execute()`` call, in call order.
        """
        return self.results

    def get_pipelines_results(self):
        """
        Return the individual results of every Lens object.

        Returns
        -------
        list
            ``get_results()`` of each Lens, in the order of ``self.lenses``.
        """
        return [x.get_results() for x in self.lenses]