-
Notifications
You must be signed in to change notification settings - Fork 376
/
Copy pathupdate_command_args.py
124 lines (108 loc) · 4.23 KB
/
update_command_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
This script is used to add or update arguments in a shell script. For example,
```bash
python update_command_args.py scripts/train/tulu3/grpo_fast_8b.sh \
--cluster ai2/augusta-google-1 \
--priority normal \
--dataset_mixer_list allenai/RLVR-GSM 1.0 allenai/RLVR-MATH 1.0 \
--image costah/open_instruct_dev0320_11 | uv run bash
```
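would replace the `--cluster`, `--priority`, `--dataset_mixer_list`, and `--image` arguments in the script with the ones specified; any argument the script does not already contain is appended to the end of the command.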
"""
import sys
import argparse
from typing import List
def read_shell_script(filename: str) -> str:
    """Return the entire contents of the file at ``filename`` as one string."""
    with open(filename, 'r') as script:
        text = script.read()
    return text
def modify_command(content: str, new_args: dict) -> str:
    """Return ``content`` with the given ``--flag`` values replaced or appended.

    ``content`` is split on single spaces. A token that starts with ``--`` opens
    a flag, and every following token up to the next flag is treated as that
    flag's values (this includes line-continuation tokens such as ``"\\\\\\n"``).
    Tokens before the first flag are copied through unchanged.

    Args:
        content: Text of the shell command to rewrite.
        new_args: Mapping of flag name (without the leading ``--``) to either a
            single string value or a list of string values. It is not modified.

    Returns:
        The rewritten command. Flags found in ``new_args`` are emitted with the
        new values in their original position; flags in ``new_args`` that the
        script does not contain are appended at the end. Every substituted or
        appended flag is followed by a backslash-newline and indentation.
    """
    # Joined on " ", these tokens become a continuation backslash, a newline and
    # indentation, which mimics how the scripts lay out one flag per line.
    continuation = ["\\\n", "", "", ""]
    # Work on a copy so the caller's mapping is left untouched.
    pending = dict(new_args)
    tokens: List[str] = []
    flag = None
    script_values: List[str] = []

    def emit_current_flag() -> None:
        """Write out the flag being collected, preferring its override."""
        if flag is None:
            return
        if flag in pending:
            tokens.extend(_flag_tokens(flag, pending.pop(flag)))
            tokens.extend(continuation)
        else:
            tokens.extend(_flag_tokens(flag, script_values))

    # Trailing whitespace is dropped first: a final newline would otherwise end
    # the shell command before any appended flag, leaving it on its own line.
    for part in content.rstrip().split(" "):
        if part.startswith("--"):
            emit_current_flag()
            flag = part[2:]
            script_values = []
        elif flag is None:
            tokens.append(part)
        else:
            script_values.append(part)
    # The last flag in the script has no following flag to trigger its output.
    emit_current_flag()

    # Whatever is still pending did not exist in the script, so append it.
    for name, value in pending.items():
        tokens.extend(_flag_tokens(name, value))
        tokens.extend(continuation)
    return " ".join(tokens)


def _flag_tokens(flag: str, values) -> List[str]:
    """Return ``--flag`` followed by its value, or values when given a list."""
    if isinstance(values, list):
        return [f"--{flag}", *values]
    return [f"--{flag}", values]


def _count_flag_values(argv: List[str]) -> dict:
    """Map each ``--flag`` in ``argv`` to the number of values that follow it."""
    counts = {}
    current = None
    for token in argv:
        if token.startswith("--"):
            current = token.lstrip("-")
            counts[current] = 0
        elif current is not None:
            counts[current] += 1
    return counts


def main():
    """Command-line entry point: print the script with the given flags applied.

    The first command-line argument is the shell script to read. Every
    ``--flag value [value ...]`` group after it becomes an override passed to
    ``modify_command``. The rewritten script is printed to stdout.
    """
    if len(sys.argv) < 2:
        # stdout is meant to be piped into a shell, so usage goes to stderr.
        print("Usage: python update_command_args.py <shell_script> [--arg value ...]", file=sys.stderr)
        sys.exit(1)

    script_file = sys.argv[1]
    override_argv = sys.argv[2:]

    # The flag names are not known in advance, so register whatever was typed.
    parser = argparse.ArgumentParser()
    for name, num_values in _count_flag_values(override_argv).items():
        # One value parses to a plain string, several to a list. An explicit
        # dest keeps dashes in the name so it still matches the script's flag.
        nargs = "?" if num_values == 1 else "+"
        parser.add_argument(f"--{name}", nargs=nargs, dest=name)
    args = parser.parse_args(override_argv)
    new_args = {k: v for k, v in vars(args).items() if v is not None}

    # Read and modify the script
    content = read_shell_script(script_file)
    print(modify_command(content, new_args))


def test_modify_command():
    """Self-check: overrides replace flags in place and new flags are appended."""
    content = "python train.py --dataset_mixer_list xxx 1.0 --cluster ai2/augusta-google-1 --priority normal"
    new_args = {
        "dataset_mixer_list": ["xxx", "1.0"],
        "cluster": "ai2/augusta-google-1",
        "priority": "normal",
        "image": "costah/open_instruct_dev0320_11",
    }
    modified_content = modify_command(content, new_args)
    normalized_content = " ".join(modified_content.replace("\\\n", "").split())
    assert normalized_content == "python train.py --dataset_mixer_list xxx 1.0 --cluster ai2/augusta-google-1 --priority normal --image costah/open_instruct_dev0320_11"
if __name__ == "__main__":
    # Run the built-in self-test before handling the command line.
    test_modify_command()
    main()