Skip to content

Commit 0bfcf59

Browse files
chcostGitHub Enterprise
authored and
GitHub Enterprise
committed
Merge pull request #15 from codeflare/develop
2 parents 250dd23 + 1dde22a commit 0bfcf59

26 files changed

+726
-456
lines changed
File renamed without changes.

com/ibm/research/ray/graph/Datamodel.py renamed to codeflare/pipelines/Datamodel.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,61 @@
33

44

55
class Xy:
6-
__X__ = None
7-
__y__ = None
6+
"""
7+
Holder class for Xy, where X is array-like and y is array-like. This is the base
8+
data structure for fully materialized X and y.
9+
"""
810

911
def __init__(self, X, y):
1012
self.__X__ = X
1113
self.__y__ = y
1214

15+
"""
16+
Returns the holder value of X
17+
"""
18+
1319
def get_x(self):
1420
return self.__X__
1521

22+
"""
23+
Returns the holder value of y
24+
"""
25+
1626
def get_y(self):
1727
return self.__y__
1828

1929

2030
class XYRef:
31+
"""
32+
Holder class that maintains a pointer/reference to X and y. The goal of this is to provide
33+
a holder to the object references of Ray. This is used for passing outputs from a transform/fit
34+
to the next stage of the pipeline. Since the references can be potentially in flight (or being
35+
computed), these holders are essential to the pipeline constructs.
36+
"""
37+
2138
def __init__(self, Xref, yref):
22-
self.Xref = Xref
23-
self.yref = yref
39+
self.__Xref__ = Xref
40+
self.__yref__ = yref
2441

2542
def get_Xref(self):
26-
return self.Xref
43+
"""
44+
Returns the object reference to X
45+
"""
46+
return self.__Xref__
2747

2848
def get_yref(self):
29-
return self.yref
30-
31-
32-
class AndFunc(ABC):
33-
@abstractmethod
34-
def eval(self, xy_list: list) -> Xy:
35-
raise NotImplementedError("Please implement this method")
49+
"""
50+
Returns the object reference to y
51+
"""
52+
return self.__yref__
3653

3754

3855
class Node(ABC):
39-
__node_name__ = None
56+
"""
57+
An abstract node class that captures the basic information about a Node.
58+
The hash code of this node is the hash of the node name, and two nodes are equal if the
59+
node name and the type of the node match.
60+
"""
4061

4162
def __str__(self):
4263
return self.__node_name__
@@ -46,29 +67,71 @@ def get_and_flag(self):
4667
raise NotImplementedError("Please implement this method")
4768

4869
def __hash__(self):
70+
"""
71+
Hash code, defined as the hash code of the node name
72+
73+
:return: Hash code
74+
"""
4975
return self.__node_name__.__hash__()
5076

5177
def __eq__(self, other):
78+
"""
79+
Equality with another node, defined as the class names match and the
80+
node names match
81+
82+
:param other: Node to compare with
83+
:return: True if nodes are equal, else False
84+
"""
5285
return (
5386
self.__class__ == other.__class__ and
5487
self.__node_name__ == other.__node_name__
5588
)
5689

5790

5891
class OrNode(Node):
92+
"""
93+
Or node, which is the basic node that would be the equivalent of any SKlearn pipeline
94+
stage. This node is initialized with an estimator that needs to extend sklearn.BaseEstimator.
95+
"""
5996
__estimator__ = None
6097

6198
def __init__(self, node_name: str, estimator: BaseEstimator):
99+
"""
100+
Init the OrNode with the name of the node and the estimator.
101+
102+
:param node_name: Name of the node
103+
:param estimator: The base estimator
104+
"""
62105
self.__node_name__ = node_name
63106
self.__estimator__ = estimator
64107

65108
def get_estimator(self) -> BaseEstimator:
109+
"""
110+
Return the estimator that this node was initialized with
111+
112+
:return: Estimator
113+
"""
66114
return self.__estimator__
67115

68116
def get_and_flag(self):
117+
"""
118+
A flag to check if node is AND or not. By definition, this is NOT
119+
an AND node.
120+
:return: False, always
121+
"""
69122
return False
70123

71124

125+
class AndFunc(ABC):
126+
"""
127+
Abstract function held by an And node; eval combines a list of Xy objects into a single Xy.
128+
"""
129+
130+
@abstractmethod
131+
def eval(self, xy_list: list) -> Xy:
132+
raise NotImplementedError("Please implement this method")
133+
134+
72135
class AndNode(Node):
73136
__andfunc__ = None
74137

@@ -127,10 +190,9 @@ def get_object_ref(self):
127190

128191

129192
class Pipeline:
130-
__pre_graph__ = {}
131-
__post_graph__ = {}
132-
__node_levels__ = None
133-
__level_nodes__ = None
193+
"""
194+
The pipeline class that defines the DAG structure composed of Node(s).
195+
"""
134196

135197
def __init__(self):
136198
self.__pre_graph__ = {}
Lines changed: 84 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
import ray
22

3-
from com.ibm.research.ray.graph.Datamodel import OrNode
4-
from com.ibm.research.ray.graph.Datamodel import AndNode
5-
from com.ibm.research.ray.graph.Datamodel import Edge
6-
from com.ibm.research.ray.graph.Datamodel import Pipeline
7-
from com.ibm.research.ray.graph.Datamodel import XYRef
3+
from codeflare.pipelines.Datamodel import OrNode
4+
from codeflare.pipelines.Datamodel import AndNode
5+
from codeflare.pipelines.Datamodel import Edge
6+
from codeflare.pipelines.Datamodel import Pipeline
7+
from codeflare.pipelines.Datamodel import XYRef
8+
from codeflare.pipelines.Datamodel import Xy
89

910
import sklearn.base as base
1011
from enum import Enum
1112

1213

1314
class ExecutionType(Enum):
14-
TRAIN = 0,
15-
TEST = 1
15+
FIT = 0,
16+
PREDICT = 1,
17+
SCORE = 2
1618

1719

1820
@ray.remote
@@ -22,87 +24,84 @@ def execute_or_node_inner(node: OrNode, train_mode: ExecutionType, Xy: XYRef):
2224
X = ray.get(Xy.get_Xref())
2325
y = ray.get(Xy.get_yref())
2426

25-
if train_mode == ExecutionType.TRAIN:
27+
if train_mode == ExecutionType.FIT:
2628
if base.is_classifier(estimator) or base.is_regressor(estimator):
2729
# Always clone before fit, else fit is invalid
2830
cloned_estimator = base.clone(estimator)
2931
cloned_estimator.fit(X, y)
3032
# TODO: For now, make yref passthrough - this has to be fixed more comprehensively
3133
res_Xref = ray.put(cloned_estimator.predict(X))
32-
result = [XYRef(res_Xref, Xy.get_yref())]
34+
result = XYRef(res_Xref, Xy.get_yref())
3335
return result
3436
else:
3537
# No need to clone as it is a transform pass through on the fitted estimator
36-
res_Xref = ray.put(estimator.fit_transform(X))
37-
result = [XYRef(res_Xref, Xy.get_yref())]
38+
res_Xref = ray.put(estimator.fit_transform(X, y))
39+
result = XYRef(res_Xref, Xy.get_yref())
3840
return result
39-
elif train_mode == ExecutionType.TEST:
41+
elif train_mode == ExecutionType.SCORE:
42+
if base.is_classifier(estimator) or base.is_regressor(estimator):
43+
cloned_estimator = base.clone(estimator)
44+
cloned_estimator.fit(X, y)
45+
res_Xref = ray.put(cloned_estimator.score(X, y))
46+
result = XYRef(res_Xref, Xy.get_yref())
47+
return result
48+
else:
49+
# No need to clone as it is a transform pass through on the fitted estimator
50+
res_Xref = ray.put(estimator.fit_transform(X, y))
51+
result = XYRef(res_Xref, Xy.get_yref())
52+
return result
53+
elif train_mode == ExecutionType.PREDICT:
4054
# Predict mode does not clone as it is a simple predict or transform
4155
if base.is_classifier(estimator) or base.is_regressor(estimator):
4256
res_Xref = estimator.predict(X)
43-
result = [XYRef(res_Xref, Xy.get_yref())]
57+
result = XYRef(res_Xref, Xy.get_yref())
4458
return result
4559
else:
4660
res_Xref = estimator.transform(X)
47-
result = [XYRef(res_Xref, Xy.get_yref())]
61+
result = XYRef(res_Xref, Xy.get_yref())
4862
return result
4963

5064

51-
###
52-
# in_args is a dict from Node to list of XYRefs
53-
###
54-
def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
55-
nodes_by_level = pipeline.get_nodes_by_level()
56-
57-
# track args per edge
58-
edge_args = {}
59-
for node, node_in_args in in_args.items():
60-
pre_edges = pipeline.get_pre_edges(node)
61-
for pre_edge in pre_edges:
62-
edge_args[pre_edge] = node_in_args
63-
64-
for nodes in nodes_by_level:
65-
for node in nodes:
66-
pre_edges = pipeline.get_pre_edges(node)
67-
post_edges = pipeline.get_post_edges(node)
68-
if not node.get_and_flag():
69-
execute_or_node(node, pre_edges, edge_args, post_edges, mode)
70-
else:
71-
cross_product = execute_and_node(node, pre_edges, edge_args, post_edges)
72-
for element in cross_product:
73-
print(element)
74-
75-
out_args = {}
76-
last_level_nodes = nodes_by_level[pipeline.compute_max_level()]
77-
for last_level_node in last_level_nodes:
78-
edge = Edge(last_level_node, None)
79-
out_args[last_level_node] = edge_args[edge]
65+
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType):
66+
for pre_edge in pre_edges:
67+
Xyref_ptrs = edge_args[pre_edge]
68+
exec_xyrefs = []
69+
for xy_ref_ptr in Xyref_ptrs:
70+
xy_ref = ray.get(xy_ref_ptr)
71+
inner_result = execute_or_node_inner.remote(node, mode, xy_ref)
72+
exec_xyrefs.append(inner_result)
8073

81-
return out_args
74+
for post_edge in post_edges:
75+
if post_edge not in edge_args.keys():
76+
edge_args[post_edge] = []
77+
edge_args[post_edge].extend(exec_xyrefs)
8278

8379

8480
@ray.remote
85-
def and_node_eval(and_func, xy_list):
86-
Xy = and_func.eval(xy_list)
87-
res_Xref = ray.put(Xy.get_x())
88-
res_yref = ray.put(Xy.get_y())
81+
def and_node_eval(and_func, Xyref_list):
82+
xy_list = []
83+
for Xyref in Xyref_list:
84+
X = ray.get(Xyref.get_Xref())
85+
y = ray.get(Xyref.get_yref())
86+
xy_list.append(Xy(X, y))
87+
88+
res_Xy = and_func.eval(xy_list)
89+
res_Xref = ray.put(res_Xy.get_x())
90+
res_yref = ray.put(res_Xy.get_y())
8991
return XYRef(res_Xref, res_yref)
9092

9193

92-
def execute_and_node_inner(node: AndNode, elements):
94+
def execute_and_node_inner(node: AndNode, Xyref_ptrs):
9395
and_func = node.get_and_func()
9496
result = []
9597

96-
for element in elements:
97-
xy_list = []
98-
for Xy in element:
99-
X = ray.get(Xy.get_Xref())
100-
y = ray.get(Xy.get_yref())
98+
Xyref_list = []
99+
for Xyref_ptr in Xyref_ptrs:
100+
Xyref = ray.get(Xyref_ptr)
101+
Xyref_list.append(Xyref)
101102

102-
Xy = Xy(X, y)
103-
xy_list.append(Xy)
104-
Xyref = and_node_eval(and_func, xy_list)
105-
result.append(Xyref)
103+
Xyref_ptr = and_node_eval.remote(and_func, Xyref_list)
104+
result.append(Xyref_ptr)
106105
return result
107106

108107

@@ -116,24 +115,36 @@ def execute_and_node(node, pre_edges, edge_args, post_edges):
116115
cross_product = itertools.product(*edge_args_lists)
117116

118117
for element in cross_product:
119-
exec_xyrefs = execute_and_node_inner(node, element)
118+
exec_xyref_ptrs = execute_and_node_inner(node, element)
120119
for post_edge in post_edges:
121120
if post_edge not in edge_args.keys():
122121
edge_args[post_edge] = []
123-
edge_args[post_edge].extend(exec_xyrefs)
122+
edge_args[post_edge].extend(exec_xyref_ptrs)
124123

125124

126-
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType):
127-
for pre_edge in pre_edges:
128-
Xyrefs = edge_args[pre_edge]
129-
exec_xyrefs = []
130-
for xy_ref in Xyrefs:
131-
xy_ref_list = ray.get(xy_ref)
132-
for xy_ref in xy_ref_list:
133-
inner_result = execute_or_node_inner.remote(node, mode, xy_ref)
134-
exec_xyrefs.append(inner_result)
125+
def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
126+
nodes_by_level = pipeline.get_nodes_by_level()
135127

136-
for post_edge in post_edges:
137-
if post_edge not in edge_args.keys():
138-
edge_args[post_edge] = []
139-
edge_args[post_edge].extend(exec_xyrefs)
128+
# track args per edge
129+
edge_args = {}
130+
for node, node_in_args in in_args.items():
131+
pre_edges = pipeline.get_pre_edges(node)
132+
for pre_edge in pre_edges:
133+
edge_args[pre_edge] = node_in_args
134+
135+
for nodes in nodes_by_level:
136+
for node in nodes:
137+
pre_edges = pipeline.get_pre_edges(node)
138+
post_edges = pipeline.get_post_edges(node)
139+
if not node.get_and_flag():
140+
execute_or_node(node, pre_edges, edge_args, post_edges, mode)
141+
elif node.get_and_flag():
142+
execute_and_node(node, pre_edges, edge_args, post_edges)
143+
144+
out_args = {}
145+
last_level_nodes = nodes_by_level[pipeline.compute_max_level()]
146+
for last_level_node in last_level_nodes:
147+
edge = Edge(last_level_node, None)
148+
out_args[last_level_node] = edge_args[edge]
149+
150+
return out_args
File renamed without changes.
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
Metadata-Version: 1.0
2-
Name: ray-graphs
2+
Name: codeflare-pipelines
33
Version: 1.0.0
4-
Summary: Ray
4+
Summary: Codeflare pipelines
55
Home-page: UNKNOWN
6-
Author: rganti
6+
Author: Raghu Ganti, Mudhakar Srivatsa
77
Author-email: rganti@us.ibm.com
8-
License: UNKNOWN
8+
License: Apache v2.0
99
Description: UNKNOWN
1010
Platform: UNKNOWN

0 commit comments

Comments
 (0)