Skip to content

Commit 5707d33

Browse files
d4l3kfacebook-github-bot
authored and committed
added KFP adapter for new torchx.specs.api interface (#10)
Summary: This creates a KFP adapter for the new specs interface.

```
import torchx.specs.api as torchx
from kfp import compiler, components, dsl

app = torchx.Application(...)
kfp_copy: Callable = component_from_app(app)

def pipeline() -> dsl.PipelineParam:
    a = kfp_copy()
    b = kfp_copy()
    b.after(a)
    return b

compiler.Compiler().compile(pipeline, "pipeline.zip")
```

Pull Request resolved: pytorch#10
Reviewed By: kiukchung
Differential Revision: D28658161
Pulled By: d4l3k
fbshipit-source-id: 39969b6da9ab651c71dd348db67f610dbf981c9a
1 parent e5e45e8 commit 5707d33

File tree

2 files changed

+106
-3
lines changed

2 files changed

+106
-3
lines changed

torchx/pipelines/kfp/adapter.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77

88
import copy
99
import os
10-
from typing import Type, Callable, List, Optional, Dict
10+
from typing import Callable, Dict, List, Optional, Type
1111

1212
import yaml
1313
from kfp import components, dsl
1414
from torchx.runtime.component import Component, is_optional
15+
from torchx.specs import api
1516

1617
from .version import __version__ as __version__ # noqa F401
1718

19+
1820
TORCHX_CONTAINER_ENV: str = "TORCHX_CONTAINER"
1921
TORCHX_CONTAINER: str = os.getenv(
2022
TORCHX_CONTAINER_ENV,
@@ -111,3 +113,41 @@ def outputs(self) -> Dict[str, dsl.PipelineParam]:
111113
@property
112114
def output(self) -> dsl.PipelineParam:
113115
...
116+
117+
118+
def component_spec_from_app(app: api.Application) -> str:
    """
    Render a single-role TorchX application as a KFP component spec.

    The application is validated against what this adapter can express and
    then serialized with ``yaml.dump``.

    Args:
        app: the application to convert. It must have exactly one role with
            one replica, a container that is not ``api.NULL_CONTAINER``, no
            ``base_image``, resources equal to ``api.NULL_RESOURCE`` and an
            empty ``port_map``.

    Returns:
        The YAML text of the component spec, containing ``name``,
        ``description`` and ``implementation.container`` (``image``,
        ``command`` and ``env``).

    Raises:
        AssertionError: if any of the constraints above does not hold.
    """
    assert len(app.roles) == 1, f"KFP adapter only support one role, got {app.roles}"

    role = app.roles[0]
    # The check is on the role, so the message must report the role's replica
    # count as well. Formatting a different object here would mislead, or fail
    # outright while building the message if that attribute does not exist.
    assert (
        role.num_replicas == 1
    ), f"KFP adapter only supports one replica, got {role.num_replicas}"
    assert role.container != api.NULL_CONTAINER, "missing container for KFP"

    container = role.container
    assert container.base_image is None, "KFP adapter does not support base_image"
    # Resources are expected to be attached on the pipeline side, so the app
    # itself must not carry any.
    assert (
        container.resources == api.NULL_RESOURCE
    ), "KFP adapter requires you to specify resources in the pipeline"
    assert len(container.port_map) == 0, "KFP adapter does not support port_map"

    # The container command is the role entrypoint followed by its arguments.
    command = [role.entrypoint, *role.args]

    spec = {
        "name": f"{app.name}-{role.name}",
        "description": f"KFP wrapper for TorchX component {app.name}, role {role.name}",
        "implementation": {
            "container": {
                "image": container.image,
                "command": command,
                "env": role.env,
            }
        },
    }
    return yaml.dump(spec)
148+
149+
150+
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def component_from_app(app: api.Application) -> Callable:
    """
    Build a KFP component from a TorchX application.

    The application is first rendered to a component spec by
    ``component_spec_from_app`` and that text is then handed to
    ``components.load_component_from_text``.

    Args:
        app: the application to convert; it must satisfy the constraints
            enforced by ``component_spec_from_app``.

    Returns:
        Whatever ``components.load_component_from_text`` produces for the
        generated spec.
    """
    return components.load_component_from_text(component_spec_from_app(app))

torchx/pipelines/kfp/test/adapter_test.py

+65-2
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@
88
import os.path
99
import tempfile
1010
import unittest
11-
from typing import TypedDict, Optional
11+
from typing import Callable, Optional, TypedDict
1212

1313
from kfp import compiler, components, dsl
1414
from torchx.apps.io.copy import Copy
15-
from torchx.pipelines.kfp.adapter import component_spec, TorchXComponent
15+
from torchx.pipelines.kfp.adapter import (
16+
TorchXComponent,
17+
component_from_app,
18+
component_spec,
19+
component_spec_from_app,
20+
)
1621
from torchx.runtime.component import Component
22+
from torchx.specs import api
1723

1824

1925
class Config(TypedDict):
@@ -103,3 +109,60 @@ class KFPCopy(TorchXComponent, component=Copy, image="foo"):
103109
print(copy)
104110
# pyre-fixme[16]: `KFPCopy` has no attribute `component_ref`.
105111
self.assertEqual(copy.component_ref.spec.implementation.container.image, "foo")
112+
113+
114+
class KFPSpecsTest(unittest.TestCase):
    """
    Exercises the KFP adapter helpers that consume torchx.specs.api objects.
    """

    def _test_app(self) -> api.Application:
        """Return a one-role, one-replica application used by the tests."""
        role = api.Role(name="trainer")
        role = role.runs(
            "main",
            "--output-path",
            "blah",
            FOO="bar",
        )
        role = role.on(api.Container(image="pytorch/torchx:latest"))
        role = role.replicas(1)

        return api.Application("test").of(role)

    def test_component_spec_from_app(self) -> None:
        """The generated spec loads in KFP and matches the expected YAML."""
        expected = """description: KFP wrapper for TorchX component test, role trainer
implementation:
  container:
    command:
    - main
    - --output-path
    - blah
    env:
      FOO: bar
    image: pytorch/torchx:latest
name: test-trainer
"""
        spec = component_spec_from_app(self._test_app())

        loaded = components.load_component_from_text(spec)
        self.assertIsNotNone(loaded)
        self.assertEqual(spec, expected)

    def test_pipeline(self) -> None:
        """A pipeline chaining two instances of the component compiles."""
        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
        kfp_copy: Callable = component_from_app(self._test_app())

        def pipeline() -> dsl.PipelineParam:
            first = kfp_copy()
            second = kfp_copy()
            # Run the second step only after the first one.
            second.after(first)
            return second

        with tempfile.TemporaryDirectory() as workdir:
            package_path = os.path.join(workdir, "pipeline.zip")
            compiler.Compiler().compile(pipeline, package_path)

0 commit comments

Comments
 (0)