import torchx
import torchx.specs as specs
+ from torchx.components.structured_arg import StructuredJArgument, StructuredNameArgument
from torchx.specs import macros

_TORCH_DEBUG_FLAGS: Dict[str, str] = {
"""

+ def spmd(
+     *args: str,
+     script: Optional[str] = None,
+     m: Optional[str] = None,
+     image: str = torchx.IMAGE,
+     name: str = "/",
+     h: str = "gpu.small",
+     j: str = "1x1",
+     env: Optional[Dict[str, str]] = None,
+     max_retries: int = 0,
+     mounts: Optional[List[str]] = None,
+     debug: bool = False,
+ ) -> specs.AppDef:
+ """
98
+ Usage (by script): torchx run spmd -j 2x8 -h aws_p4d.24xlarge --name my_experiment/trial_1 --script path/to/my/trainer.py -foo bar
99
+
100
+ Usage (by module): torchx run spmd -j 2x8 -h aws_p4d.24xlarge --name my_experiment/trial_1 -m path.to.my.trainer -foo bar
101
+
102
+ Usage (infer GPU count): torchx run spmd -j 2 -h p4d.24xlarge ... (same as -j 2x8)
103
+
104
+ Creates a torchx.specs.AppDef (Job Definition) for a Single-Process-Multiple-Data (SPMD)
105
+ style application. See: https://en.wikipedia.org/wiki/Single_program,_multiple_data.
106
+
107
+ SPMD launches `n x m` (set via the `-j nxm` option) copies of the same program,
108
+ where `n` is the number of nodes (hosts) and `m` is the number of processes on each node.
109
+
110
+ If you have a distributed PyTorch script (DDP, FSDP, RPC) use this component to launch
111
+ the distributed application. You can also use `-j 1x1` to launch a single process application
112
+ which would be equivalent to launching with regular `python` except that your application
113
+ can safely call `torch.distributed.init_process_group(backend)`.
114
+
115
+ Note: For multi-node distributed runs, the hosts MUST have a network route to each other
116
+ AND port 29500 should be open on all hosts. Please check your security group settings.
117
+
118
+
119
+ Args:
120
+ args: the arguments to the main module or script (e.g. my/trainer.py -foo bar)
121
+ (for docker based runs) the script path must be relative to the WORKDIR of the image
122
+ script:
123
+ m: the main module name (e.g. my.module.trainer). When this option is used, the `script_args` are passed
124
+ as the arguments to the main module). Invoking my module is useful when the relative/absolute path
125
+ of the main script is unknown w.r.t the WORKDIR of the image. Use this option when it makes sense to
126
+ invoke the main script via `python -m <MAIN.MODULE>`.
127
+ image: the base docker image of the workspace, if workspace is disabled, then the image of the job
128
+ name: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``
129
+ h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
130
+ j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
131
+ env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
132
+ max_retries: the number of scheduler retries allowed
133
+ rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
134
+ Only takes effect when running multi-node. When running single node, this parameter
135
+ is ignored and a random free port is chosen.
136
+ mounts: (for docker based runs only) mounts to mount into the worker environment/container
137
+ (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
138
+ debug: whether to run with preset debug flags enabled
139
+
140
+ """
+ 
+     if env is None:
+         env = {}
+ 
+     return ddp(
+         *args,
+         script=script,
+         m=m,
+         image=image,
+         name=name,
+         h=h,
+         j=str(StructuredJArgument.parse_from(h, j)),
+         env=env,
+         max_retries=max_retries,
+         mounts=mounts,
+         debug=debug,
+     )
+ 
+ 
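# --- Editor's sketch (illustrative, not part of this diff) -------------------
# Programmatic use of the new `spmd` component, assuming it ships in
# `torchx.components.dist` alongside `ddp`; the script path, resource name,
# and trainer flags below are placeholders.
from torchx.components.dist import spmd

app = spmd(
    "-foo", "bar",                    # forwarded to the trainer via *args
    script="path/to/my/trainer.py",   # hypothetical training script
    name="my_experiment/trial_1",     # {experimentname}/{runname}
    h="aws_p4d.24xlarge",             # must be a registered named resource
    j="2",                            # nproc_per_node inferred from the host's GPU count
)
# Roughly equivalent CLI invocation (mirroring the docstring above):
#   torchx run spmd -j 2 -h aws_p4d.24xlarge --name my_experiment/trial_1 --script path/to/my/trainer.py -foo bar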
def ddp(
    *script_args: str,
    script: Optional[str] = None,
    m: Optional[str] = None,
    image: str = torchx.IMAGE,
-     name: Optional[str] = None,
+     name: str = "/",
    h: Optional[str] = None,
    cpu: int = 2,
    gpu: int = 0,
@@ -114,7 +191,8 @@ def ddp(
        script: script or binary to run within the image
        m: the python module path to run
        image: image (e.g. docker)
-         name: job name override (uses the script name if not specified)
+         name: job name override in the following format: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``.
+             Uses the script or module name if ``{runname}`` is not specified.
        cpu: number of cpus per replica
        gpu: number of gpus per replica
        memMB: cpu memory in MB per replica
@@ -138,14 +216,6 @@ def ddp(
    # nproc_per_node: number of processes on each node
    min_nnodes, max_nnodes, nproc_per_node, nnodes_rep = parse_nnodes(j)

-     if script:
-         # script name/module no extension
-         role_name = Path(script).stem
-     elif m:
-         role_name = m.rpartition(".")[2]
-     else:
-         raise ValueError("failed to compute role_name")
- 
    rdzv_backend = "c10d"
    if max_nnodes == 1:
        # using port 0 makes elastic choose a free random port which is ok
@@ -165,8 +235,16 @@ def ddp(

    if env is None:
        env = {}
-     env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))

+     argname = StructuredNameArgument.parse_from(
+         name=name,
+         m=m,
+         script=script,
+     )
+ 
+     env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name
+ 
+     env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
    if debug:
        env.update(_TORCH_DEBUG_FLAGS)

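# --- Editor's sketch (illustrative, not part of this diff) -------------------
# Expected flow of the structured --name argument into tracking and the AppDef
# name, based on the usage above and the `name` docstring; the exact parsing
# rules and defaults live in torchx.components.structured_arg.
argname = StructuredNameArgument.parse_from(
    name="my_experiment/trial_1", m=None, script="path/to/my/trainer.py"
)
env = {"TORCHX_TRACKING_EXPERIMENT_NAME": argname.experiment_name}  # "my_experiment"
app_name = argname.run_name  # "trial_1"; falls back to the script/module name when omitted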
@@ -193,10 +271,10 @@ def ddp(
        cmd += ["-m", m]
    cmd += script_args
    return specs.AppDef(
-         name=name or role_name,
+         name=argname.run_name,
        roles=[
            specs.Role(
-                 name=role_name,
+                 name=get_role_name(script, m),
                image=image,
                min_replicas=min_nnodes,
                entrypoint="bash",
@@ -214,6 +292,17 @@ def ddp(
    )


+ def get_role_name(script: Optional[str], m: Optional[str]) -> str:
+     if script:
+         # script name/module no extension
+         role_name = Path(script).stem
+     elif m:
+         role_name = m.rpartition(".")[2]
+     else:
+         raise ValueError("failed to compute role_name")
+     return role_name
+ 
+ 
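# --- Editor's note (illustrative, not part of this diff) ---------------------
# Expected behavior of the extracted helper, with hypothetical inputs:
#   get_role_name("path/to/trainer.py", None)  -> "trainer"   (script stem, extension dropped)
#   get_role_name(None, "my.module.trainer")   -> "trainer"   (last segment of the module path)
#   get_role_name(None, None)                  -> raises ValueError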
def _args_join(args: Iterable[str]) -> str:
    """
    _args_join is like shlex.join but if the argument is wrapped in _noquote