Coverage for credoai/prism/prism.py: 89%

36 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-13 21:56 +0000

1from typing import List 

2 

3from credoai.lens import Lens 

4from credoai.prism.task import Task 

5from credoai.utils import ValidationError, global_logger 

6from credoai.utils.common import NotRunError 

7 

8 

class Prism:
    """
    **Experimental**

    Orchestrates the run of complex operations (Tasks) on
    an arbitrary number of Lens objects.

    Parameters
    ----------
    lenses : List[Lens]
        A list of Lens objects. The only requirement is for the Lens objects to
        be instantiated with their necessary artifacts. One or multiple Lens objects
        can be provided; each Task will validate that the number of objects provided
        is suitable for its requirement. A single Lens instance is also accepted
        and treated as a one-element list. Any other iterable is copied into a list.
    task : Task
        A task instance, instantiated with all the required parameters.

    Attributes
    ----------
    run_flag : bool
        True once every Lens pipeline has been run (or found already run).
    results : list
        One entry per call to ``execute``: the Task's ``get_results()`` output.
    compare_results : list
        Initialized empty; not populated by this class.
    """

    def __init__(self, lenses: List[Lens], task: Task):
        # Convenience: allow a bare Lens instead of a one-element list.
        if isinstance(lenses, Lens):
            lenses = [lenses]
        # Materialize into a list: the lenses are traversed several times
        # (validation, run, result collection), so a one-shot iterator such
        # as a generator would otherwise be exhausted by validation.
        self.lenses = list(lenses)
        self.task = task
        self._validate()
        self.run_flag = False
        self.compare_results: List = []
        self.results: List = []

    def _validate(self):
        """
        Validate Prism parameters.

        Raises
        ------
        ValidationError
            If any element of ``lenses`` is not a Lens, or if ``task`` is not
            a Task instance.
        """
        for step in self.lenses:
            if not isinstance(step, Lens):
                raise ValidationError("Step must be a Lens instance")
        if not isinstance(self.task, Task):
            raise ValidationError(
                "The parameter task should be an instance of credoai.prism.Task"
            )

    @staticmethod
    def _lens_label(lens, index: int) -> str:
        """Return a label for ``lens`` in log messages, safe when it has no model."""
        # NOTE(review): a Lens presumably may carry no model (e.g. data-only
        # assessments); never let logging itself raise AttributeError.
        model_name = getattr(getattr(lens, "model", None), "name", None)
        return str(model_name) if model_name is not None else f"Lens #{index}"

    def _run(self):
        """
        Run the Lens pipelines that have not been run yet.

        A Lens counts as already run when its ``get_results()`` does not raise
        ``NotRunError``. ``run_flag`` is set only after every Lens has been
        handled, so a failing ``run()`` leaves it False.
        """
        for index, step in enumerate(self.lenses):
            label = self._lens_label(step, index)
            try:
                step.get_results()
            except NotRunError:
                # Log before running so the message reflects what is happening.
                global_logger.info(f"Running {label} pipeline")
                step.run()
            else:
                global_logger.info(f"{label} pipeline already run")

        self.run_flag = True

    def execute(self):
        """
        Execute the task run.

        Runs any Lens pipelines that have not been run yet. It then calls the
        task on the lenses and appends the task's results to ``self.results``.
        Each call appends one more entry.
        """
        # Make sure the underlying Lens pipelines have been run first.
        if not self.run_flag:
            self._run()

        self.results.append(self.task(pipelines=self.lenses).run().get_results())

    def get_results(self):
        """
        Return prism results.

        Returns
        -------
        list
            The accumulated Task results, one entry per ``execute`` call.
        """
        return self.results

    def get_pipelines_results(self):
        """
        Return the individual results of all the Lens runs.

        Returns
        -------
        list
            ``get_results()`` of each Lens, in the order they were provided.
            A Lens that has not been run presumably raises ``NotRunError``
            here (the same signal ``_run`` relies on).
        """
        return [lens.get_results() for lens in self.lenses]