Skip to content

Commit 4d7e2a1

Browse files
Kiuk Chungfacebook-github-bot
Kiuk Chung
authored andcommitted
register fb schedulers using entry_point (pytorch#12)
Summary: Pull Request resolved: pytorch#12 NOTE: I cleaned up a few tsm references while I was at it since removing `//torchx/schedulers/fb:registry` buck target ended up with a ton of autodeps changes that touched a lot of TARGETS files. I recommend taking a look at `torchx_fb.dist-info/entry_points.txt` and `torchx/util/entrypoint.py` for the CORE changes I originally intended for this diff. 1. Now that we have `torchx_fb.dist-info/entry_point.txt` setup to hook in plugin points for Facebook, move the scheduler registry to entrypoint so that we don't have to mess with `base_path` and have autodeps hate us. 1. cleaned up a few places were we still have a back reference to tsm. Reviewed By: tierex Differential Revision: D28622469 fbshipit-source-id: 907f384d518a695bbd9087a033d771450dcb07c8
1 parent b7e1fd0 commit 4d7e2a1

File tree

11 files changed

+188
-192
lines changed

11 files changed

+188
-192
lines changed

torchx/cli/cmd_log.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from typing import Optional
1515
from urllib.parse import urlparse
1616

17-
import torchx.specs.lib as torchx
1817
from pyre_extensions import none_throws
18+
from torchx import specs
1919
from torchx.cli.cmd_base import SubCommand
20-
from torchx.runner import Runner
20+
from torchx.runner import Runner, get_runner
21+
from torchx.specs.api import make_app_handle
22+
2123

2224
GREEN = "\033[32m"
2325
ENDC = "\033[0m"
@@ -64,10 +66,10 @@ def get_logs(identifier: str, regex: Optional[str], should_tail: bool = False) -
6466
app_id = path[1]
6567
role_name = path[2]
6668

67-
session_ = torchx.run(name=session_name)
68-
app_handle = torchx.make_app_handle(scheduler_backend, session_name, app_id)
69+
runner = get_runner(name=session_name)
70+
app_handle = make_app_handle(scheduler_backend, session_name, app_id)
6971

70-
app = none_throws(session_.describe(app_handle))
72+
app = none_throws(runner.describe(app_handle))
7173

7274
if len(path) == 4:
7375
replica_ids = [int(id) for id in path[3].split(",") if id]
@@ -98,7 +100,7 @@ def get_logs(identifier: str, regex: Optional[str], should_tail: bool = False) -
98100
thread = threading.Thread(
99101
target=print_log_lines,
100102
args=(
101-
session_,
103+
runner,
102104
app_handle,
103105
role_name,
104106
replica_id,
@@ -126,7 +128,7 @@ def get_logs(identifier: str, regex: Optional[str], should_tail: bool = False) -
126128
raise threads_exceptions[0]
127129

128130

129-
def find_role_replicas(app: torchx.Application, role_name: str) -> Optional[int]:
131+
def find_role_replicas(app: specs.Application, role_name: str) -> Optional[int]:
130132
for role in app.roles:
131133
if role_name == role.name:
132134
return role.num_replicas

torchx/cli/cmd_run.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,20 @@ def _parse_run_config(arg: str) -> specs.RunConfig:
100100
_CONFIG_EXT = ".torchx"
101101

102102

103+
def get_abspath(relpath: str) -> str:
104+
module = __name__.replace(".", path.sep) # torchx/cli/cmd_run
105+
module_path, _ = path.splitext(__file__) # $root/torchx/cli/cmd_run
106+
root = module_path.replace(module, "")
107+
return path.join(root, relpath)
108+
109+
103110
def get_file_contents(conf_file: str) -> Optional[str]:
104111
"""
105112
Reads the ``conf_file`` relative to the root of the project.
106113
Returns ``None`` if ``$root/$conf_file`` does not exist.
107114
Example: ``get_file("torchx/cli/config/foo.txt")``
108115
"""
109-
110-
module = __name__.replace(".", path.sep) # torchx/cli/cmd_run
111-
module_path, _ = path.splitext(__file__) # $root/torchx/cli/cmd_run
112-
root = module_path.replace(module, "")
113-
abspath = path.join(root, conf_file)
114-
116+
abspath = get_abspath(conf_file)
115117
if path.exists(abspath):
116118
with open(abspath, "r") as f:
117119
return f.read()
@@ -148,8 +150,12 @@ def read_conf_file(conf_file: str) -> str:
148150

149151

150152
def _builtins() -> List[str]:
153+
config_dir = entrypoints.load("torchx.file", "get_dir_path", default=get_abspath)(
154+
_CONFIG_DIR
155+
)
156+
151157
builtins: List[str] = []
152-
for f in os.listdir(_CONFIG_DIR):
158+
for f in os.listdir(config_dir):
153159
_, extension = os.path.splitext(f)
154160
if f.endswith(_CONFIG_EXT):
155161
builtins.append(f)

torchx/cli/test/cmd_log_test.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
import io
99
import unittest
10-
from typing import Optional, Iterator
11-
from unittest.mock import patch, MagicMock
10+
from typing import Iterator, Optional
11+
from unittest.mock import MagicMock, patch
1212

1313
from torchx.cli.cmd_log import ENDC, GREEN, get_logs
1414
from torchx.specs.api import Application, Role, parse_app_handle
@@ -22,11 +22,11 @@ class SentinelError(Exception):
2222
pass
2323

2424

25-
SESSION = "torchx.specs.lib.run"
25+
RUNNER = "torchx.cli.cmd_log.get_runner"
2626

2727

28-
class MockSession:
29-
def __call__(self, name: Optional[str] = None) -> "MockSession":
28+
class MockRunner:
29+
def __call__(self, name: Optional[str] = None) -> "MockRunner":
3030
return self
3131

3232
def describe(self, app_handle: str) -> Application:
@@ -61,10 +61,10 @@ def test_cmd_log_bad_job_identifier(self, exit_mock: MagicMock) -> None:
6161
get_logs("local:///SparseNNApplication/", "QPS.*")
6262
exit_mock.assert_called_once_with(1)
6363

64-
@patch(SESSION, new_callable=MockSession)
64+
@patch(RUNNER, new_callable=MockRunner)
6565
@patch("sys.exit", side_effect=SentinelError)
6666
def test_cmd_log_unknown_role(
67-
self, exit_mock: MagicMock, session_mock: MagicMock
67+
self, exit_mock: MagicMock, mock_runner: MagicMock
6868
) -> None:
6969
with self.assertRaises(SentinelError):
7070
get_logs(
@@ -74,10 +74,10 @@ def test_cmd_log_unknown_role(
7474

7575
exit_mock.assert_called_once_with(1)
7676

77-
@patch(SESSION, new_callable=MockSession)
77+
@patch(RUNNER, new_callable=MockRunner)
7878
@patch("sys.stdout", new_callable=io.StringIO)
7979
def test_cmd_log_all_replicas(
80-
self, stdout_mock: MagicMock, session_mock: MagicMock
80+
self, stdout_mock: MagicMock, mock_runner: MagicMock
8181
) -> None:
8282
get_logs("local://test-session/SparseNNApplication/trainer", regex="INFO")
8383
self.assertSetEqual(
@@ -92,10 +92,10 @@ def test_cmd_log_all_replicas(
9292
set(stdout_mock.getvalue().split("\n")),
9393
)
9494

95-
@patch(SESSION, new_callable=MockSession)
95+
@patch(RUNNER, new_callable=MockRunner)
9696
@patch("sys.stdout", new_callable=io.StringIO)
9797
def test_cmd_log_one_replica(
98-
self, stdout_mock: MagicMock, session_mock: MagicMock
98+
self, stdout_mock: MagicMock, mock_runner: MagicMock
9999
) -> None:
100100
get_logs("local://test-session/SparseNNApplication/trainer/0", regex=None)
101101
self.assertSetEqual(
@@ -110,10 +110,10 @@ def test_cmd_log_one_replica(
110110
set(stdout_mock.getvalue().split("\n")),
111111
)
112112

113-
@patch(SESSION, new_callable=MockSession)
113+
@patch(RUNNER, new_callable=MockRunner)
114114
@patch("sys.stdout", new_callable=io.StringIO)
115115
def test_cmd_log_some_replicas(
116-
self, stdout_mock: MagicMock, session_mock: MagicMock
116+
self, stdout_mock: MagicMock, mock_runner: MagicMock
117117
) -> None:
118118
get_logs("local://test-session/SparseNNApplication/trainer/0,2", regex="WARN")
119119
self.assertSetEqual(
@@ -127,11 +127,11 @@ def test_cmd_log_some_replicas(
127127
set(stdout_mock.getvalue().split("\n")),
128128
)
129129

130-
@patch(SESSION, new_callable=MockSession)
131-
def test_print_log_lines_throws(self, session_mock: MagicMock) -> None:
130+
@patch(RUNNER, new_callable=MockRunner)
131+
def test_print_log_lines_throws(self, mock_runner: MagicMock) -> None:
132132
# makes sure that when the function executed in the threadpool
133133
# errors out; we raise the exception all the way through
134-
with patch.object(session_mock, "log_lines") as log_lines_mock:
134+
with patch.object(mock_runner, "log_lines") as log_lines_mock:
135135
log_lines_mock.side_effect = RuntimeError
136136
with self.assertRaises(RuntimeError):
137137
get_logs(

torchx/cli/test/cmd_runopts_test.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
import unittest
1010

1111
from torchx.cli.cmd_runopts import CmdRunopts
12-
13-
# @manual=//torchx/schedulers/fb:registry
14-
from torchx.schedulers.registry import get_schedulers
12+
from torchx.schedulers import get_schedulers
1513

1614

1715
class CmdRunOptsTest(unittest.TestCase):
@@ -22,7 +20,7 @@ def test_run(self) -> None:
2220
cmd_runopts = CmdRunopts()
2321
cmd_runopts.add_arguments(parser)
2422

25-
schedulers = get_schedulers("test").keys()
23+
schedulers = get_schedulers(session_name="test").keys()
2624
test_configs = [[]] + [[scheduler] for scheduler in schedulers]
2725
for scheduler in test_configs:
2826
with self.subTest(scheduler=scheduler):

torchx/runner/api.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
from pyre_extensions import none_throws
1515
from torchx.runner.events import log_event
16+
from torchx.schedulers import get_schedulers
1617
from torchx.schedulers.api import Scheduler
17-
from torchx.schedulers.registry import get_schedulers
1818
from torchx.specs.api import (
1919
NULL_CONTAINER,
2020
AppDryRunInfo,
@@ -426,6 +426,5 @@ def get_runner(name: Optional[str] = None, **scheduler_params: Any) -> Runner:
426426
if not name:
427427
name = f"torchx_{getpass.getuser()}"
428428

429-
scheduler_params["session_name"] = name
430-
schedulers = get_schedulers(**scheduler_params)
429+
schedulers = get_schedulers(session_name=name, **scheduler_params)
431430
return Runner(name, schedulers)

torchx/schedulers/__init__.py

+28
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,31 @@
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import Dict
9+
10+
import torchx.schedulers.local_scheduler as local_scheduler
11+
from torchx.schedulers.api import Scheduler
12+
from torchx.specs.api import SchedulerBackend
13+
from torchx.util.entrypoints import load_group
14+
15+
16+
def get_schedulers(
17+
session_name: str,
18+
# pyre-ignore[2]
19+
**scheduler_params
20+
) -> Dict[SchedulerBackend, Scheduler]:
21+
22+
schedulers = load_group(
23+
"torchx.schedulers",
24+
default={
25+
"local": local_scheduler.create_scheduler,
26+
"default": local_scheduler.create_scheduler,
27+
},
28+
ignore_missing=True,
29+
)
30+
31+
return {
32+
scheduler_backend: scheduler_factory_method(session_name, **scheduler_params)
33+
for scheduler_backend, scheduler_factory_method in schedulers.items()
34+
}

torchx/schedulers/registry.py

-21
This file was deleted.

torchx/schedulers/test/registry_test.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,26 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import unittest
9+
from typing import Any, Dict, Optional
10+
from unittest.mock import MagicMock, patch
911

12+
from torchx.schedulers import get_schedulers
1013
from torchx.schedulers.local_scheduler import LocalScheduler
11-
from torchx.schedulers.registry import get_schedulers
14+
15+
16+
class spy_load_group:
17+
def __call__(
18+
self,
19+
group: str,
20+
default: Dict[str, Any],
21+
ignore_missing: Optional[bool] = False,
22+
) -> Dict[str, Any]:
23+
return default
1224

1325

1426
class SchedulersTest(unittest.TestCase):
15-
def test_get_local_schedulers(self) -> None:
27+
@patch("torchx.schedulers.load_group", new_callable=spy_load_group)
28+
def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
1629
schedulers = get_schedulers(session_name="test_session")
1730
self.assertTrue(isinstance(schedulers["local"], LocalScheduler))
1831
self.assertTrue(isinstance(schedulers["default"], LocalScheduler))

0 commit comments

Comments
 (0)