6
6
7
7
import argparse
8
8
import logging
9
- import sys
10
- from typing import Callable , Optional
11
9
12
10
from tabulate import tabulate
13
11
14
12
from torchx .cli .cmd_base import SubCommand
15
13
from torchx .runner .api import get_configured_trackers
16
14
from torchx .tracker .api import build_trackers , TrackerBase
17
- from torchx .util .types import none_throws
18
15
19
16
logger : logging .Logger = logging .getLogger (__name__ )
20
17
21
18
22
- def _requires_tracker (
23
- command : Callable [["CmdTracker" , argparse .Namespace ], None ]
24
- ) -> Callable [["CmdTracker" , argparse .Namespace ], None ]:
25
- """Checks that command has valid tracker setup"""
26
-
27
- def wrapper (self : "CmdTracker" , args : argparse .Namespace ) -> None :
28
- if not self .tracker :
29
- logger .error ("Exiting since no trackers were configured." )
30
- sys .exit (1 )
31
- command (self , args )
32
-
33
- return wrapper
34
-
35
-
36
19
class CmdTracker (SubCommand ):
37
20
"""
38
21
Prototype TorchX tracker subcommand that allows querying data by
@@ -49,30 +32,28 @@ class CmdTracker(SubCommand):
49
32
def __init__ (self ) -> None :
50
33
"""
51
34
Queries available tracker implementations and uses the first available one.
52
-
53
- Since the instance needs to be available to setup torchx arguments, subcommands
54
- utilize `_requires_tracker()` annotation to check that tracker is available
55
- when invoked.
56
35
"""
57
- self .tracker : Optional [TrackerBase ] = None
58
- configured_trackers = get_configured_trackers ()
59
- if configured_trackers :
60
- trackers = build_trackers (configured_trackers )
61
- if trackers :
62
- self .tracker = next (iter (trackers ))
63
- logger .info (f"Using { self .tracker } to query data" )
64
- else :
65
- logger .error ("No trackers were configured!" )
36
+
37
+ @property
38
+ def tracker (self ) -> TrackerBase :
39
+ trackers = list (build_trackers (get_configured_trackers ()))
40
+ if trackers :
41
+ logger .info (f"Using `{ trackers [0 ]} ` tracker to query data" )
42
+ return trackers [0 ]
43
+ else :
44
+ raise RuntimeError (
45
+ "No trackers configured."
46
+ " See: https://pytorch.org/torchx/latest/runtime/tracking.html"
47
+ )
66
48
67
49
def add_list_job_arguments (self , subparser : argparse .ArgumentParser ) -> None :
68
50
subparser .add_argument (
69
51
"--parent-run-id" , type = str , help = "Optional job parent run ID"
70
52
)
71
53
72
- @_requires_tracker
73
54
def list_jobs_command (self , args : argparse .Namespace ) -> None :
74
55
parent_run_id = args .parent_run_id
75
- job_ids = none_throws ( self .tracker ) .run_ids (parent_run_id = parent_run_id )
56
+ job_ids = self .tracker .run_ids (parent_run_id = parent_run_id )
76
57
77
58
tabulated_job_ids = [[job_id ] for job_id in job_ids ]
78
59
print (tabulate (tabulated_job_ids , headers = ["JOB ID" ]))
@@ -91,17 +72,15 @@ def add_job_lineage_arguments(self, subparser: argparse.ArgumentParser) -> None:
91
72
)
92
73
subparser .add_argument ("RUN_ID" , type = str , help = "Job run ID" )
93
74
94
- @_requires_tracker
95
75
def job_lineage_command (self , args : argparse .Namespace ) -> None :
96
76
raise NotImplementedError ("" )
97
77
98
78
def add_metadata_arguments (self , subparser : argparse .ArgumentParser ) -> None :
99
79
subparser .add_argument ("RUN_ID" , type = str , help = "Job run ID" )
100
80
101
- @_requires_tracker
102
81
def list_metadata_command (self , args : argparse .Namespace ) -> None :
103
82
run_id = args .RUN_ID
104
- metadata = none_throws ( self .tracker ) .metadata (run_id )
83
+ metadata = self .tracker .metadata (run_id )
105
84
print_data = [[k , v ] for k , v in metadata .items ()]
106
85
107
86
print (tabulate (print_data , headers = ["ID" , "VALUE" ]))
@@ -113,21 +92,16 @@ def add_artifacts_arguments(self, subparser: argparse.ArgumentParser) -> None:
113
92
114
93
subparser .add_argument ("RUN_ID" , type = str , help = "Job run ID" )
115
94
116
- @_requires_tracker
117
95
def list_artifacts_command (self , args : argparse .Namespace ) -> None :
118
96
run_id = args .RUN_ID
119
97
artifact_filter = args .artifact
120
98
121
- artifacts = none_throws (self .tracker ).artifacts (run_id )
122
- artifacts = artifacts .values ()
99
+ artifacts = list (self .tracker .artifacts (run_id ).values ())
123
100
124
101
if artifact_filter :
125
- artifacts = [
126
- artifact for artifact in artifacts if artifact .name == artifact_filter
127
- ]
128
- print_data = [
129
- [artifact .name , artifact .path , artifact .metadata ] for artifact in artifacts
130
- ]
102
+ artifacts = [a for a in artifacts if a .name == artifact_filter ]
103
+
104
+ print_data = [[a .name , a .path , a .metadata ] for a in artifacts ]
131
105
132
106
print (tabulate (print_data , headers = ["ARTIFACT" , "PATH" , "METADATA" ]))
133
107
0 commit comments