Skip to content

Commit 5b8022b

Browse files
authored
(torchx/tracker) Add support for MLflowTracker (pytorch#707)
1 parent 3267fa9 commit 5b8022b

17 files changed

+1871
-28
lines changed

dev-requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ google-cloud-runtimeconfig>=0.33.2
1212
hydra-core
1313
ipython
1414
kfp==1.8.9
15+
mlflow-skinny
1516
moto==4.1.3
1617
pyre-extensions
1718
pyre-check

torchx/components/dist.py

+102-13
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262

6363
import torchx
6464
import torchx.specs as specs
65+
from torchx.components.structured_arg import StructuredJArgument, StructuredNameArgument
6566
from torchx.specs import macros
6667

6768
_TORCH_DEBUG_FLAGS: Dict[str, str] = {
@@ -80,12 +81,88 @@
8081
"""
8182

8283

84+
def spmd(
    *args: str,
    script: Optional[str] = None,
    m: Optional[str] = None,
    image: str = torchx.IMAGE,
    name: str = "/",
    h: str = "gpu.small",
    j: str = "1x1",
    env: Optional[Dict[str, str]] = None,
    max_retries: int = 0,
    mounts: Optional[List[str]] = None,
    debug: bool = False,
) -> specs.AppDef:
    """
    Usage (by script): torchx run spmd -j 2x8 -h aws_p4d.24xlarge --name my_experiment/trial_1 --script path/to/my/trainer.py -foo bar

    Usage (by module): torchx run spmd -j 2x8 -h aws_p4d.24xlarge --name my_experiment/trial_1 -m path.to.my.trainer -foo bar

    Usage (infer GPU count): torchx run spmd -j 2 -h p4d.24xlarge ... (same as -j 2x8)

    Creates a torchx.specs.AppDef (Job Definition) for a Single-Process-Multiple-Data (SPMD)
    style application. See: https://en.wikipedia.org/wiki/Single_program,_multiple_data.

    SPMD launches `n x m` (set via the `-j nxm` option) copies of the same program,
    where `n` is the number of nodes (hosts) and `m` is the number of processes on each node.

    If you have a distributed PyTorch script (DDP, FSDP, RPC) use this component to launch
    the distributed application. You can also use `-j 1x1` to launch a single process application
    which would be equivalent to launching with regular `python` except that your application
    can safely call `torch.distributed.init_process_group(backend)`.

    Note: For multi-node distributed runs, the hosts MUST have a network route to each other
    AND port 29500 should be open on all hosts. Please check your security group settings.


    Args:
        args: the arguments to the main module or script (e.g. -foo bar)
        script: the main script path (e.g. my/trainer.py). For docker based runs the script
            path must be relative to the WORKDIR of the image
        m: the main module name (e.g. my.module.trainer). When this option is used, the `args` are passed
            as the arguments to the main module. Invoking the main module is useful when the relative/absolute
            path of the main script is unknown w.r.t the WORKDIR of the image. Use this option when it makes
            sense to invoke the main script via `python -m <MAIN.MODULE>`.
        image: the base docker image of the workspace, if workspace is disabled, then the image of the job
        name: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``
        h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
        j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
        env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
        max_retries: the number of scheduler retries allowed
        mounts: (for docker based runs only) mounts to mount into the worker environment/container
            (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
        debug: whether to run with preset debug flags enabled

    """

    # avoid a shared mutable default; the delegate mutates env in place
    if env is None:
        env = {}

    # spmd is a thin wrapper over ddp: it normalizes the `-j` spec against the
    # host type (e.g. infers nproc_per_node from the host's GPU count) and
    # forwards everything else unchanged.
    return ddp(
        *args,
        script=script,
        m=m,
        image=image,
        name=name,
        h=h,
        j=str(StructuredJArgument.parse_from(h, j)),
        env=env,
        max_retries=max_retries,
        mounts=mounts,
        debug=debug,
    )
158+
159+
83160
def ddp(
84161
*script_args: str,
85162
script: Optional[str] = None,
86163
m: Optional[str] = None,
87164
image: str = torchx.IMAGE,
88-
name: Optional[str] = None,
165+
name: str = "/",
89166
h: Optional[str] = None,
90167
cpu: int = 2,
91168
gpu: int = 0,
@@ -114,7 +191,8 @@ def ddp(
114191
script: script or binary to run within the image
115192
m: the python module path to run
116193
image: image (e.g. docker)
117-
name: job name override (uses the script name if not specified)
194+
name: job name override in the following format: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``.
195+
Uses the script or module name if ``{runname}`` not specified.
118196
cpu: number of cpus per replica
119197
gpu: number of gpus per replica
120198
memMB: cpu memory in MB per replica
@@ -138,14 +216,6 @@ def ddp(
138216
# nproc_per_node: number of processes on each node
139217
min_nnodes, max_nnodes, nproc_per_node, nnodes_rep = parse_nnodes(j)
140218

141-
if script:
142-
# script name/module no extension
143-
role_name = Path(script).stem
144-
elif m:
145-
role_name = m.rpartition(".")[2]
146-
else:
147-
raise ValueError("failed to compute role_name")
148-
149219
rdzv_backend = "c10d"
150220
if max_nnodes == 1:
151221
# using port 0 makes elastic chose a free random port which is ok
@@ -165,8 +235,16 @@ def ddp(
165235

166236
if env is None:
167237
env = {}
168-
env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
169238

239+
argname = StructuredNameArgument.parse_from(
240+
name=name,
241+
m=m,
242+
script=script,
243+
)
244+
245+
env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name
246+
247+
env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
170248
if debug:
171249
env.update(_TORCH_DEBUG_FLAGS)
172250

@@ -193,10 +271,10 @@ def ddp(
193271
cmd += ["-m", m]
194272
cmd += script_args
195273
return specs.AppDef(
196-
name=name or role_name,
274+
name=argname.run_name,
197275
roles=[
198276
specs.Role(
199-
name=role_name,
277+
name=get_role_name(script, m),
200278
image=image,
201279
min_replicas=min_nnodes,
202280
entrypoint="bash",
@@ -214,6 +292,17 @@ def ddp(
214292
)
215293

216294

295+
def get_role_name(script: Optional[str], m: Optional[str]) -> str:
    """
    Derive a role name from the main script or module.

    A script path yields its file stem (``path/to/trainer.py`` -> ``trainer``);
    a module path yields its last dotted component (``a.b.trainer`` -> ``trainer``).
    ``script`` takes precedence when both are given.

    Raises:
        ValueError: if neither ``script`` nor ``m`` is provided.
    """
    if script:
        # script file name without its extension
        return Path(script).stem
    if m:
        # text after the final "." (the whole string if there is no ".")
        return m.rpartition(".")[2]
    raise ValueError("failed to compute role_name")
304+
305+
217306
def _args_join(args: Iterable[str]) -> str:
218307
"""
219308
_args_join is like shlex.join but if the argument is wrapped in _noquote

0 commit comments

Comments
 (0)