1
1
import ray
2
2
3
- from com .ibm .research .ray .graph .Datamodel import OrNode
4
- from com .ibm .research .ray .graph .Datamodel import AndNode
5
- from com .ibm .research .ray .graph .Datamodel import Edge
6
- from com .ibm .research .ray .graph .Datamodel import Pipeline
7
- from com .ibm .research .ray .graph .Datamodel import XYRef
3
+ from codeflare .pipelines .Datamodel import OrNode
4
+ from codeflare .pipelines .Datamodel import AndNode
5
+ from codeflare .pipelines .Datamodel import Edge
6
+ from codeflare .pipelines .Datamodel import Pipeline
7
+ from codeflare .pipelines .Datamodel import XYRef
8
+ from codeflare .pipelines .Datamodel import Xy
8
9
9
10
import sklearn .base as base
10
11
from enum import Enum
11
12
12
13
13
14
class ExecutionType(Enum):
    """Mode in which a pipeline node's estimator is executed.

    FIT     -- fit the estimator on the incoming data.
    PREDICT -- apply the estimator (predict/transform) to the incoming data.
    SCORE   -- fit the estimator and score it on the incoming data.
    """
    # Values are plain ints. A trailing comma after a value would silently
    # turn it into a one-element tuple, making the members inconsistent.
    FIT = 0
    PREDICT = 1
    SCORE = 2
16
18
17
19
18
20
@ray .remote
@@ -22,87 +24,84 @@ def execute_or_node_inner(node: OrNode, train_mode: ExecutionType, Xy: XYRef):
22
24
X = ray .get (Xy .get_Xref ())
23
25
y = ray .get (Xy .get_yref ())
24
26
25
- if train_mode == ExecutionType .TRAIN :
27
+ if train_mode == ExecutionType .FIT :
26
28
if base .is_classifier (estimator ) or base .is_regressor (estimator ):
27
29
# Always clone before fit, else fit is invalid
28
30
cloned_estimator = base .clone (estimator )
29
31
cloned_estimator .fit (X , y )
30
32
# TODO: For now, make yref passthrough - this has to be fixed more comprehensively
31
33
res_Xref = ray .put (cloned_estimator .predict (X ))
32
- result = [ XYRef (res_Xref , Xy .get_yref ())]
34
+ result = XYRef (res_Xref , Xy .get_yref ())
33
35
return result
34
36
else :
35
37
# No need to clone as it is a transform pass through on the fitted estimator
36
- res_Xref = ray .put (estimator .fit_transform (X ))
37
- result = [ XYRef (res_Xref , Xy .get_yref ())]
38
+ res_Xref = ray .put (estimator .fit_transform (X , y ))
39
+ result = XYRef (res_Xref , Xy .get_yref ())
38
40
return result
39
- elif train_mode == ExecutionType .TEST :
41
+ elif train_mode == ExecutionType .SCORE :
42
+ if base .is_classifier (estimator ) or base .is_regressor (estimator ):
43
+ cloned_estimator = base .clone (estimator )
44
+ cloned_estimator .fit (X , y )
45
+ res_Xref = ray .put (cloned_estimator .score (X , y ))
46
+ result = XYRef (res_Xref , Xy .get_yref ())
47
+ return result
48
+ else :
49
+ # No need to clone as it is a transform pass through on the fitted estimator
50
+ res_Xref = ray .put (estimator .fit_transform (X , y ))
51
+ result = XYRef (res_Xref , Xy .get_yref ())
52
+ return result
53
+ elif train_mode == ExecutionType .PREDICT :
40
54
# Test mode does not clone as it is a simple predict or transform
41
55
if base .is_classifier (estimator ) or base .is_regressor (estimator ):
42
56
res_Xref = estimator .predict (X )
43
- result = [ XYRef (res_Xref , Xy .get_yref ())]
57
+ result = XYRef (res_Xref , Xy .get_yref ())
44
58
return result
45
59
else :
46
60
res_Xref = estimator .transform (X )
47
- result = [ XYRef (res_Xref , Xy .get_yref ())]
61
+ result = XYRef (res_Xref , Xy .get_yref ())
48
62
return result
49
63
50
64
51
- ###
52
- # in_args is a dict from Node to list of XYRefs
53
- ###
54
- def execute_pipeline (pipeline : Pipeline , mode : ExecutionType , in_args : dict ):
55
- nodes_by_level = pipeline .get_nodes_by_level ()
56
-
57
- # track args per edge
58
- edge_args = {}
59
- for node , node_in_args in in_args .items ():
60
- pre_edges = pipeline .get_pre_edges (node )
61
- for pre_edge in pre_edges :
62
- edge_args [pre_edge ] = node_in_args
63
-
64
- for nodes in nodes_by_level :
65
- for node in nodes :
66
- pre_edges = pipeline .get_pre_edges (node )
67
- post_edges = pipeline .get_post_edges (node )
68
- if not node .get_and_flag ():
69
- execute_or_node (node , pre_edges , edge_args , post_edges , mode )
70
- else :
71
- cross_product = execute_and_node (node , pre_edges , edge_args , post_edges )
72
- for element in cross_product :
73
- print (element )
74
-
75
- out_args = {}
76
- last_level_nodes = nodes_by_level [pipeline .compute_max_level ()]
77
- for last_level_node in last_level_nodes :
78
- edge = Edge (last_level_node , None )
79
- out_args [last_level_node ] = edge_args [edge ]
65
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType):
    """Schedule an OR node once per input arriving on each of its pre-edges.

    :param node: the node to execute; forwarded to ``execute_or_node_inner``
    :param pre_edges: incoming edges of ``node``; each must be a key of ``edge_args``
    :param edge_args: dict from edge to a list of Ray object references;
        updated in place with the outputs for every edge in ``post_edges``
    :param post_edges: outgoing edges of ``node`` that receive the results
    :param mode: execution mode forwarded to ``execute_or_node_inner``
    :return: None; results are published through ``edge_args``
    """
    for pre_edge in pre_edges:
        # Each pointer is dereferenced on the driver and one remote task is
        # launched per input; the returned futures are collected, not awaited.
        exec_xyrefs = [
            execute_or_node_inner.remote(node, mode, ray.get(xy_ref_ptr))
            for xy_ref_ptr in edge_args[pre_edge]
        ]

        # Fan the futures out to every outgoing edge, creating the list on
        # first use.
        for post_edge in post_edges:
            edge_args.setdefault(post_edge, []).extend(exec_xyrefs)
82
78
83
79
84
80
@ray.remote
def and_node_eval(and_func, Xyref_list):
    """Combine several inputs with an AND function inside a Ray task.

    :param and_func: object exposing ``eval``; it is called with a list of
        ``Xy`` holders and its result must offer ``get_x()`` and ``get_y()``
    :param Xyref_list: iterable of ``XYRef`` objects whose X and y references
        are fetched from the object store
    :return: an ``XYRef`` pointing at the combined X and y, both stored with
        ``ray.put``
    """
    # Pull the actual X and y behind every reference pair.
    materialised = [
        Xy(ray.get(ref.get_Xref()), ray.get(ref.get_yref()))
        for ref in Xyref_list
    ]

    combined = and_func.eval(materialised)

    # Store the combined result so downstream consumers only see references.
    return XYRef(ray.put(combined.get_x()), ray.put(combined.get_y()))
90
92
91
93
92
def execute_and_node_inner(node: AndNode, Xyref_ptrs):
    """Launch the AND function of ``node`` on one combination of inputs.

    :param node: AND node supplying the combiner via ``get_and_func()``
    :param Xyref_ptrs: iterable of Ray object references, each resolved with
        ``ray.get`` before being handed to the remote evaluation
    :return: a single-element list holding the object reference returned by
        ``and_node_eval.remote``
    """
    combiner = node.get_and_func()

    # Dereference each pointer on the caller's side, one at a time.
    resolved = [ray.get(ptr) for ptr in Xyref_ptrs]

    # A list is returned (rather than the bare reference) so callers can
    # extend their per-edge argument lists with it.
    return [and_node_eval.remote(combiner, resolved)]
107
106
108
107
@@ -116,24 +115,36 @@ def execute_and_node(node, pre_edges, edge_args, post_edges):
116
115
cross_product = itertools .product (* edge_args_lists )
117
116
118
117
for element in cross_product :
119
- exec_xyrefs = execute_and_node_inner (node , element )
118
+ exec_xyref_ptrs = execute_and_node_inner (node , element )
120
119
for post_edge in post_edges :
121
120
if post_edge not in edge_args .keys ():
122
121
edge_args [post_edge ] = []
123
- edge_args [post_edge ].extend (exec_xyrefs )
122
+ edge_args [post_edge ].extend (exec_xyref_ptrs )
124
123
125
124
126
- def execute_or_node (node , pre_edges , edge_args , post_edges , mode : ExecutionType ):
127
- for pre_edge in pre_edges :
128
- Xyrefs = edge_args [pre_edge ]
129
- exec_xyrefs = []
130
- for xy_ref in Xyrefs :
131
- xy_ref_list = ray .get (xy_ref )
132
- for xy_ref in xy_ref_list :
133
- inner_result = execute_or_node_inner .remote (node , mode , xy_ref )
134
- exec_xyrefs .append (inner_result )
125
def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
    """Run every node of ``pipeline`` level by level and collect the outputs.

    :param pipeline: the pipeline graph; queried for its nodes by level, the
        pre/post edges of each node and its maximum level
    :param mode: execution mode passed on to OR nodes
    :param in_args: dict from node to that node's input arguments; the value
        is attached to every pre-edge of the node
        (presumably a list of object references -- confirm against callers)
    :return: dict from each node on the last level to the arguments collected
        on its terminal edge ``Edge(node, None)``
    :raises KeyError: if a last-level node produced nothing on its terminal edge
    """
    nodes_by_level = pipeline.get_nodes_by_level()

    # track args per edge
    edge_args = {}
    for node, node_in_args in in_args.items():
        for pre_edge in pipeline.get_pre_edges(node):
            edge_args[pre_edge] = node_in_args

    for nodes in nodes_by_level:
        for node in nodes:
            pre_edges = pipeline.get_pre_edges(node)
            post_edges = pipeline.get_post_edges(node)
            # Evaluate the flag once: a node is either an AND node or an OR
            # node, so a plain if/else cannot silently skip it.
            if node.get_and_flag():
                execute_and_node(node, pre_edges, edge_args, post_edges)
            else:
                execute_or_node(node, pre_edges, edge_args, post_edges, mode)

    # Outputs of the pipeline are whatever reached the terminal edge of each
    # node on the deepest level.
    last_level_nodes = nodes_by_level[pipeline.compute_max_level()]
    out_args = {}
    for last_level_node in last_level_nodes:
        out_args[last_level_node] = edge_args[Edge(last_level_node, None)]

    return out_args
0 commit comments