
Commit abf88d4

update atari
1 parent 7ccecd2 commit abf88d4

File tree

5 files changed (+24, -11 lines)


examples/atari/README.md

Lines changed: 9 additions & 2 deletions
````diff
@@ -8,12 +8,19 @@ Then install auto-rom via:
 or:
 ```shell
 pip install autorom
-
 AutoROM --accept-license
 ```
 
+or, if you can not download the ROMs, you can download them manually from [Google Drive](https://drive.google.com/file/d/1agerLX3fP2YqUCcAkMF7v_ZtABAOhlA7/view?usp=sharing).
+Then, you can install the ROMs via:
+```shell
+pip install autorom
+AutoROM --source-file <path-to-Roms.tar.gz>
+````
+
+
 ## Usage
 
 ```shell
-python train_ppo.py --config atari_ppo.yaml
+python train_ppo.py
 ```
````

examples/atari/atari_ppo.yaml

Lines changed: 9 additions & 4 deletions
````diff
@@ -2,22 +2,27 @@ seed: 0
 lr: 2.5e-4
 critic_lr: 2.5e-4
 episode_length: 128
-ppo_epoch: 4
+gamma: 0.99
+ppo_epoch: 3
 gain: 0.01
 use_linear_lr_decay: true
 use_share_model: true
 entropy_coef: 0.01
 hidden_size: 512
-num_mini_batch: 4
-clip_param: 0.1
+num_mini_batch: 8
+clip_param: 0.2
 value_loss_coef: 0.5
+max_grad_norm: 10
+
 run_dir: ./run_results/
-experiment_name: atari_ppo
+
 log_interval: 1
 use_recurrent_policy: false
 use_valuenorm: true
 use_adv_normalize: true
+
 wandb_entity: openrl-lab
+experiment_name: atari_ppo
 
 vec_info_class:
   id: "EPS_RewardInfo"
````
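The tuned values above follow the standard PPO-Clip recipe: `clip_param` bounds how far the new policy's probability ratio may move away from the old policy in a single update. As a reminder of what raising it from 0.1 to 0.2 controls, here is a framework-free sketch of the clipped surrogate for one sample; the function name is illustrative, not part of OpenRL's API.

```python
def ppo_clipped_objective(ratio: float, advantage: float, clip_param: float = 0.2) -> float:
    """PPO-Clip surrogate for a single sample.

    ratio = pi_new(a|s) / pi_old(a|s). The clipped term replaces the ratio
    with its projection onto [1 - clip_param, 1 + clip_param], and taking
    the min keeps whichever estimate is more pessimistic.
    """
    clipped = max(1.0 - clip_param, min(ratio, 1.0 + clip_param))
    return min(ratio * advantage, clipped * advantage)


# With clip_param 0.2 (the new value), a ratio of 1.5 on a positive
# advantage is capped at 1.2 * advantage:
assert ppo_clipped_objective(1.5, 2.0) == 2.4
```

The larger clip range, together with fewer PPO epochs and more mini-batches, is a common trade-off: each update may move further, but the data is reused less aggressively.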

examples/atari/train_ppo.py

Lines changed: 3 additions & 3 deletions
````diff
@@ -43,11 +43,11 @@
 
 def train():
     cfg_parser = create_config_parser()
-    cfg = cfg_parser.parse_args()
+    cfg = cfg_parser.parse_args(["--config", "atari_ppo.yaml"])
 
     # create environment, set environment parallelism to 9
     env = make(
-        "ALE/Pong-v5", env_num=9, cfg=cfg, asynchronous=True, env_wrappers=env_wrappers
+        "ALE/Pong-v5", env_num=16, cfg=cfg, asynchronous=True, env_wrappers=env_wrappers
     )
 
     # create the neural network
@@ -56,7 +56,7 @@ def train():
         env, cfg=cfg, device="cuda" if "macOS" not in get_system_info()["OS"] else "cpu"
     )
     # initialize the trainer
-    agent = Agent(net, use_wandb=True)
+    agent = Agent(net, use_wandb=True, project_name="Pong-v5")
     # start training, set total number of training steps to 20000
 
     agent.train(total_time_steps=5000000)
````
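The first hunk passes an explicit argv list to `parse_args`, so the script always loads `atari_ppo.yaml` even when launched with no command-line flags. A minimal, self-contained sketch of that argparse pattern; the parser below is a hypothetical stand-in for OpenRL's `create_config_parser`, which registers many more options.

```python
import argparse


def create_parser() -> argparse.ArgumentParser:
    # Illustrative stand-in for OpenRL's create_config_parser().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None)
    return parser


parser = create_parser()

# Normal CLI use would read sys.argv:
#   cfg = parser.parse_args()
# The committed change instead supplies argv explicitly, so the YAML
# config is loaded regardless of how the script is invoked:
cfg = parser.parse_args(["--config", "atari_ppo.yaml"])
assert cfg.config == "atari_ppo.yaml"
```

The trade-off is that real command-line arguments are ignored; that is acceptable for a single-purpose example script like this one.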

openrl/configs/config.py

Lines changed: 2 additions & 2 deletions
````diff
@@ -726,8 +726,8 @@ def create_config_parser():
     parser.add_argument(
         "--max_grad_norm",
         type=float,
-        default=10.0,
-        help="max norm of gradients (default: 0.5)",
+        default=10,
+        help="max norm of gradients (default: 10)",
     )
     parser.add_argument(
         "--use_gae",
````
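This hunk fixes a help string that had drifted from the actual default (it claimed 0.5 while the default was 10.0). argparse never checks help text against defaults, so this kind of drift is easy to miss; one way to rule it out, sketched here with an illustrative variable name rather than anything from OpenRL's codebase, is to interpolate the default into the help string.

```python
import argparse

parser = argparse.ArgumentParser()

MAX_GRAD_NORM_DEFAULT = 10
parser.add_argument(
    "--max_grad_norm",
    type=float,
    default=MAX_GRAD_NORM_DEFAULT,
    # Interpolating the constant keeps the help text in sync with the default.
    help=f"max norm of gradients (default: {MAX_GRAD_NORM_DEFAULT})",
)

args = parser.parse_args([])
assert args.max_grad_norm == 10.0
```

Note that argparse applies `type` only to values parsed from strings, so the integer default passes through unconverted; `10 == 10.0` still holds, but passing `10.0` as the default avoids the mixed-type surprise entirely.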

setup.py

Lines changed: 1 addition & 0 deletions
````diff
@@ -41,6 +41,7 @@ def get_install_requires() -> list:
         "mujoco",
         "tqdm",
         "Jinja2",
+        "pettingzoo",
     ]
 
 
````

0 commit comments