Skip to content

Commit 47976f7

Browse files
authored
(torchx/tracker)(bugfix) honor tracker config in the usual .torchxconfig load hierarchy + lazy load tracker object creation in cmd_tracker (pytorch#711)
1 parent 5b8022b commit 47976f7

15 files changed

+261
-289
lines changed

torchx/cli/argparse_util.py

-5
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from argparse import Action, ArgumentParser, Namespace
8-
from pathlib import Path
98
from typing import Any, Dict, Optional, Sequence, Text
109

1110
from torchx.runner import config
1211

1312

14-
CONFIG_DIRS = [str(Path.home()), str(Path.cwd())]
15-
16-
1713
class _torchxconfig(Action):
1814
"""
1915
Custom argparse action that loads default torchx CLI options
@@ -40,7 +36,6 @@ def __init__(
4036
config.get_configs(
4137
prefix="cli",
4238
name=subcmd,
43-
dirs=CONFIG_DIRS,
4439
),
4540
)
4641

torchx/cli/cmd_run.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Dict, List, Optional, Tuple
1616

1717
import torchx.specs as specs
18-
from torchx.cli.argparse_util import CONFIG_DIRS, torchxconfig_run
18+
from torchx.cli.argparse_util import torchxconfig_run
1919
from torchx.cli.cmd_base import SubCommand
2020
from torchx.cli.cmd_log import get_logs
2121
from torchx.runner import config, get_runner, Runner
@@ -182,12 +182,11 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
182182

183183
scheduler_opts = runner.scheduler_run_opts(args.scheduler)
184184
cfg = scheduler_opts.cfg_from_str(args.scheduler_args)
185-
config.apply(scheduler=args.scheduler, cfg=cfg, dirs=CONFIG_DIRS)
185+
config.apply(scheduler=args.scheduler, cfg=cfg)
186186

187187
component, component_args = _parse_component_name_and_args(
188188
args.component_name_and_args,
189189
none_throws(self._subparser),
190-
dirs=CONFIG_DIRS,
191190
)
192191
try:
193192
if args.dryrun:
@@ -243,7 +242,8 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
243242

244243
def run(self, args: argparse.Namespace) -> None:
245244
os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run")
246-
component_defaults = load_sections(prefix="component", dirs=CONFIG_DIRS)
245+
component_defaults = load_sections(prefix="component")
246+
247247
with get_runner(component_defaults=component_defaults) as runner:
248248
self._run(runner, args)
249249

torchx/cli/cmd_tracker.py

+18-44
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,16 @@
66

77
import argparse
88
import logging
9-
import sys
10-
from typing import Callable, Optional
119

1210
from tabulate import tabulate
1311

1412
from torchx.cli.cmd_base import SubCommand
1513
from torchx.runner.api import get_configured_trackers
1614
from torchx.tracker.api import build_trackers, TrackerBase
17-
from torchx.util.types import none_throws
1815

1916
logger: logging.Logger = logging.getLogger(__name__)
2017

2118

22-
def _requires_tracker(
23-
command: Callable[["CmdTracker", argparse.Namespace], None]
24-
) -> Callable[["CmdTracker", argparse.Namespace], None]:
25-
"""Checks that command has valid tracker setup"""
26-
27-
def wrapper(self: "CmdTracker", args: argparse.Namespace) -> None:
28-
if not self.tracker:
29-
logger.error("Exiting since no trackers were configured.")
30-
sys.exit(1)
31-
command(self, args)
32-
33-
return wrapper
34-
35-
3619
class CmdTracker(SubCommand):
3720
"""
3821
Prototype TorchX tracker subcommand that allows querying data by
@@ -49,30 +32,28 @@ class CmdTracker(SubCommand):
4932
def __init__(self) -> None:
5033
"""
5134
Queries available tracker implementations and uses the first available one.
52-
53-
Since the instance needs to be available to setup torchx arguments, subcommands
54-
utilize `_requires_tracker()` annotation to check that tracker is available
55-
when invoked.
5635
"""
57-
self.tracker: Optional[TrackerBase] = None
58-
configured_trackers = get_configured_trackers()
59-
if configured_trackers:
60-
trackers = build_trackers(configured_trackers)
61-
if trackers:
62-
self.tracker = next(iter(trackers))
63-
logger.info(f"Using {self.tracker} to query data")
64-
else:
65-
logger.error("No trackers were configured!")
36+
37+
@property
38+
def tracker(self) -> TrackerBase:
39+
trackers = list(build_trackers(get_configured_trackers()))
40+
if trackers:
41+
logger.info(f"Using `{trackers[0]}` tracker to query data")
42+
return trackers[0]
43+
else:
44+
raise RuntimeError(
45+
"No trackers configured."
46+
" See: https://pytorch.org/torchx/latest/runtime/tracking.html"
47+
)
6648

6749
def add_list_job_arguments(self, subparser: argparse.ArgumentParser) -> None:
6850
subparser.add_argument(
6951
"--parent-run-id", type=str, help="Optional job parent run ID"
7052
)
7153

72-
@_requires_tracker
7354
def list_jobs_command(self, args: argparse.Namespace) -> None:
7455
parent_run_id = args.parent_run_id
75-
job_ids = none_throws(self.tracker).run_ids(parent_run_id=parent_run_id)
56+
job_ids = self.tracker.run_ids(parent_run_id=parent_run_id)
7657

7758
tabulated_job_ids = [[job_id] for job_id in job_ids]
7859
print(tabulate(tabulated_job_ids, headers=["JOB ID"]))
@@ -91,17 +72,15 @@ def add_job_lineage_arguments(self, subparser: argparse.ArgumentParser) -> None:
9172
)
9273
subparser.add_argument("RUN_ID", type=str, help="Job run ID")
9374

94-
@_requires_tracker
9575
def job_lineage_command(self, args: argparse.Namespace) -> None:
9676
raise NotImplementedError("")
9777

9878
def add_metadata_arguments(self, subparser: argparse.ArgumentParser) -> None:
9979
subparser.add_argument("RUN_ID", type=str, help="Job run ID")
10080

101-
@_requires_tracker
10281
def list_metadata_command(self, args: argparse.Namespace) -> None:
10382
run_id = args.RUN_ID
104-
metadata = none_throws(self.tracker).metadata(run_id)
83+
metadata = self.tracker.metadata(run_id)
10584
print_data = [[k, v] for k, v in metadata.items()]
10685

10786
print(tabulate(print_data, headers=["ID", "VALUE"]))
@@ -113,21 +92,16 @@ def add_artifacts_arguments(self, subparser: argparse.ArgumentParser) -> None:
11392

11493
subparser.add_argument("RUN_ID", type=str, help="Job run ID")
11594

116-
@_requires_tracker
11795
def list_artifacts_command(self, args: argparse.Namespace) -> None:
11896
run_id = args.RUN_ID
11997
artifact_filter = args.artifact
12098

121-
artifacts = none_throws(self.tracker).artifacts(run_id)
122-
artifacts = artifacts.values()
99+
artifacts = list(self.tracker.artifacts(run_id).values())
123100

124101
if artifact_filter:
125-
artifacts = [
126-
artifact for artifact in artifacts if artifact.name == artifact_filter
127-
]
128-
print_data = [
129-
[artifact.name, artifact.path, artifact.metadata] for artifact in artifacts
130-
]
102+
artifacts = [a for a in artifacts if a.name == artifact_filter]
103+
104+
print_data = [[a.name, a.path, a.metadata] for a in artifacts]
131105

132106
print(tabulate(print_data, headers=["ARTIFACT", "PATH", "METADATA"]))
133107

torchx/cli/main.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8+
import os
89
import sys
910
from argparse import ArgumentParser
1011
from typing import Dict, List
@@ -75,9 +76,9 @@ def create_parser(subcmds: Dict[str, SubCommand]) -> ArgumentParser:
7576
parser = ArgumentParser(description="torchx CLI")
7677
parser.add_argument(
7778
"--log_level",
78-
type=int,
79+
type=str,
7980
help="Python logging log level",
80-
default=logging.INFO,
81+
default=os.getenv("LOGLEVEL", "WARNING"),
8182
)
8283
parser.add_argument(
8384
"--version",

torchx/cli/test/argparse_util_test.py

+12-26
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,24 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import shutil
8-
import tempfile
9-
import unittest
107
from argparse import ArgumentParser
11-
from pathlib import Path
128
from unittest import mock
139

1410
from torchx.cli import argparse_util
1511
from torchx.cli.argparse_util import torchxconfig_run
12+
from torchx.test.fixtures import TestWithTmpDir
1613

14+
DEFAULT_CONFIG_DIRS = "torchx.runner.config.DEFAULT_CONFIG_DIRS"
1715

18-
CONFIG_DIRS = "torchx.cli.argparse_util.CONFIG_DIRS"
19-
20-
21-
class ArgparseUtilTest(unittest.TestCase):
22-
def _write(self, filename: str, content: str) -> Path:
23-
f = Path(self.test_dir) / filename
24-
f.parent.mkdir(parents=True, exist_ok=True)
25-
with open(f, "w") as fp:
26-
fp.write(content)
27-
return f
2816

17+
class ArgparseUtilTest(TestWithTmpDir):
2918
def setUp(self) -> None:
30-
self.test_dir = tempfile.mkdtemp(prefix="torchx_argparse_util_test")
19+
super().setUp()
3120
argparse_util._torchxconfig._subcmd_configs.clear()
3221

33-
def tearDown(self) -> None:
34-
shutil.rmtree(self.test_dir)
35-
3622
def test_torchxconfig_action(self) -> None:
37-
with mock.patch(CONFIG_DIRS, [self.test_dir]):
38-
self._write(
23+
with mock.patch(DEFAULT_CONFIG_DIRS, [str(self.tmpdir)]):
24+
self.write(
3925
".torchxconfig",
4026
"""
4127
[cli:run]
@@ -64,8 +50,8 @@ def test_torchxconfig_action(self) -> None:
6450
self.assertEqual("baz", args.workspace)
6551

6652
def test_torchxconfig_action_argparse_default(self) -> None:
67-
with mock.patch(CONFIG_DIRS, [self.test_dir]):
68-
self._write(
53+
with mock.patch(DEFAULT_CONFIG_DIRS, [str(self.tmpdir)]):
54+
self.write(
6955
".torchxconfig",
7056
"""
7157
[cli:run]
@@ -89,8 +75,8 @@ def test_torchxconfig_action_argparse_default(self) -> None:
8975
self.assertEqual("foo", args.workspace)
9076

9177
def test_torchxconfig_action_required(self) -> None:
92-
with mock.patch(CONFIG_DIRS, [self.test_dir]):
93-
self._write(
78+
with mock.patch(DEFAULT_CONFIG_DIRS, [str(self.tmpdir)]):
79+
self.write(
9480
".torchxconfig",
9581
"""
9682
[cli:run]
@@ -120,8 +106,8 @@ def test_torchxconfig_action_required(self) -> None:
120106

121107
def test_torchxconfig_action_aliases(self) -> None:
122108
# for aliases, the config file needs to declare the original arg
123-
with mock.patch(CONFIG_DIRS, [self.test_dir]):
124-
self._write(
109+
with mock.patch(DEFAULT_CONFIG_DIRS, [str(self.tmpdir)]):
110+
self.write(
125111
".torchxconfig",
126112
"""
127113
[cli:run]

torchx/cli/test/cmd_run_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ def test_run_missing(self) -> None:
134134
def test_conf_file_missing(self) -> None:
135135
# guards against existing .torchxconfig files
136136
# in user's $HOME or the CWD where the test is launched from
137-
with patch("torchx.cli.cmd_run.CONFIG_DIRS", return_value=[self.tmpdir]):
137+
with patch(
138+
"torchx.runner.config.DEFAULT_CONFIG_DIRS", return_value=[self.tmpdir]
139+
):
138140
with self.assertRaises(SystemExit):
139141
args = self.parser.parse_args(
140142
[

0 commit comments

Comments
 (0)