
Commit e8a3d77

8.19 Changed the invalid-action penalty mechanism

1 parent 723af0f commit e8a3d77

File tree

7 files changed: +116 -77 lines changed

.gitignore (+1)

@@ -6,3 +6,4 @@ __pycache__/
 .idea/
 offline_train_logs/
 online_train_logs/
+python.exe.lnk

agent.py (+16 -17)

@@ -14,20 +14,19 @@ def __init__(self, args: dict):
         self.epsilon = args["epsilon"]  # PPO ε
         self.k_epochs = args["k_epochs"]  # number of PPO update epochs
         self.entropy_coef = args["entropy_coef"]
-        self.device = args["device"]  # compute device

         # neural networks
         self.pai_set = {
-            20: get_model("actor", "./model/non_maze.pth", 20).to(self.device),
-            19: get_model("actor", "./model/maze.pth", 19).to(self.device),
-            10: get_model("actor", "./model/non_maze1v1.pth", 10).to(self.device),
-            9: get_model("actor", "./model/maze1v1.pth", 9).to(self.device)
+            20: get_model("actor", "./model/non_maze.pth", 20).to(device),
+            19: get_model("actor", "./model/maze.pth", 19).to(device),
+            10: get_model("actor", "./model/non_maze1v1.pth", 10).to(device),
+            9: get_model("actor", "./model/maze1v1.pth", 9).to(device)
         }
         self.v_set = {
-            20: get_model("critic", "./model/non_maze_critic.pth", 20).to(self.device),
-            19: get_model("critic", "./model/maze_critic.pth", 19).to(self.device),
-            10: get_model("critic", "./model/non_maze1v1_critic.pth", 10).to(self.device),
-            9: get_model("critic", "./model/maze1v1_critic.pth", 9).to(self.device)
+            20: get_model("critic", "./model/non_maze_critic.pth", 20).to(device),
+            19: get_model("critic", "./model/maze_critic.pth", 19).to(device),
+            10: get_model("critic", "./model/non_maze1v1_critic.pth", 10).to(device),
+            9: get_model("critic", "./model/maze1v1_critic.pth", 9).to(device)
         }
         self.pai = self.pai_set[20]
         self.v = self.v_set[20]

@@ -46,12 +45,12 @@ def learn(self, rep, step_t):
         """
         s, a, a_log_prob, r, s_, done = rep.get_data()
         # push everything onto the GPU
-        s = s.to(self.device)
-        a = a.to(self.device)
-        a_log_prob = a_log_prob.to(self.device)
-        r = r.to(self.device)
-        s_ = s_.to(self.device)
-        done = done.to(self.device)
+        s = s.to(device)
+        a = a.to(device)
+        a_log_prob = a_log_prob.to(device)
+        r = r.to(device)
+        s_ = s_.to(device)
+        done = done.to(device)

         # compute advantages with GAE
         adv = []

@@ -63,7 +62,7 @@ def learn(self, rep, step_t):
         for delta, d in zip(reversed(deltas.flatten()), reversed(done.flatten())):
             gae = delta + self.gamma * self.lamda * gae * (1.0 - d)
             adv.insert(0, gae)
-        adv = torch.tensor(adv, dtype=torch.float).view(-1, 1).to(self.device)
+        adv = torch.tensor(adv, dtype=torch.float).view(-1, 1).to(device)
         v_target = adv + vs

         # advantage normalization

@@ -143,7 +142,7 @@ def warm_up(self):
        Warm up: the network's first forward pass is comparatively slow.
        :return:
        """
-        t = torch.zeros([1, 12, 20, 20]).to(self.device)
+        t = torch.zeros([1, 12, 20, 20]).to(device)
        self.pai(t)
        self.v(t)
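The backward loop in learn implements the standard GAE recursion: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t), then A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}. A minimal equivalent sketch in plain PyTorch; the (1 - done) bootstrap mask inside deltas is an assumption, since that line falls outside this hunk:

import torch

def compute_gae(r, vs, vs_, done, gamma, lamda):
    # one-step TD residuals (bootstrap masked by (1 - done), assumed)
    deltas = r + gamma * vs_ * (1.0 - done) - vs
    adv = torch.zeros_like(deltas)
    gae = 0.0
    # accumulate the recursion backwards in time
    for t in reversed(range(deltas.shape[0])):
        gae = deltas[t] + gamma * lamda * gae * (1.0 - done[t])
        adv[t] = gae
    return adv, adv + vs  # advantages and critic targets (v_target = adv + vs)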

const.py (+20 -1)

@@ -5,7 +5,7 @@ class BlockType(object):
     road = 0  # null, unshown null
     obstacle = 1  # obstacle
     mountain = 2  # mountain
-    crown = 3  # crown
+    crown = 300  # crown
     city = 4  # empty-city, city


@@ -61,6 +61,7 @@ class Style(object):

 dx = [0, -1, 0, 1]
 dy = [-1, 0, 1, 0]
+inf = 999999999


 class ActionTranslator(object):

@@ -109,3 +110,21 @@ def a_to_i(self, size: int, action: torch.Tensor) -> int:
             direction = i
         return self.__indexes[size][action[0][0]][action[0][1]][direction][action[0][4]]

+    def mask(self, obs, map_size):
+        o = obs[0]
+        mask_vec = torch.zeros([len(self.__actions[map_size])], dtype=torch.long)
+        for _a in range(len(self.__actions[map_size])):
+            act = self.__actions[map_size][_a][0]
+            act -= 1
+            act[4] += 1
+
+            if int(o[2][int(act[1]) - 1][int(act[0]) - 1]) != 0:
+                # not our own tile
+                mask_vec[_a] = -inf
+            else:
+                mask_vec[_a] = 1
+        return mask_vec.to(device)


+at = ActionTranslator()
+device = torch.device("cuda")
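A minimal usage sketch for the new helper, under two assumptions that this hunk does not confirm: channel 2 of the observation is the ownership plane (read off the o[2][...] indexing above), and the action table for map size 20 is already populated:

import torch
from const import at, device

obs = torch.zeros([1, 12, 20, 20]).to(device)  # channel 2 all zeros: every tile counts as ours
mask_vec = at.mask(obs, 20)  # 1 for legal actions, -999999999 for the rest
print(mask_vec.shape)  # one entry per action in the size-20 table

One thing worth checking: if the entries of __actions are tensors, act -= 1 and act[4] += 1 mutate the stored table in place on every call, so repeated calls would drift; cloning act first would rule that out.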

main.py (+16 -14)

@@ -5,14 +5,14 @@
 from offsite_env import OffSiteEnv
 from normalization import *
 from replay_buffer import *
+from const import *


 def main(offline_train=True):
     env = OffSiteEnv() if offline_train else OnSiteEnv()

     total_steps = 0  # total step counter

-    device = torch.device("cuda")
     args = {
         "batch_size": 50,
         "state_dim": None,

@@ -25,7 +25,6 @@ def main(offline_train=True):
         "k_epochs": 10,
         "entropy_coef": 0.01,
         "autosave_step": 107,
-        "device": device
     }

     agent = PPOAgent(args)

@@ -56,9 +55,12 @@ def save_model():
         reward_scaling.reset()

         done = False
+        total_reward = 0
+        _step = total_steps
         while not done:
             a, a_log_prob = agent.predict(s)
             s_, r, done, _ = env.step(a)
+            total_reward += r

             env.render("human")

@@ -71,14 +73,21 @@ def save_model():

             # update parameters once the buffer reaches batch size
             if len(replay_buffer) == args["batch_size"]:
-                _t = threading.Thread(target=update_model)
-                _t.start()
+                _t1 = threading.Thread(target=update_model)
+                _t1.start()
                 replay_buffer.clear()

             # autosave the model; keep the lcm of batch_size and autosave_step large, since saving and updating at the same time is costly
             if total_steps % args["autosave_step"] == 0:
-                _t = threading.Thread(target=save_model)
-                _t.start()
+                _t2 = threading.Thread(target=save_model)
+                _t2.start()
+
+            # plot the reward curve to track learning progress
+            if env.episode % 10 == 0:
+                writer.add_scalar(f"offline_train_{env.mode}", total_reward, env.episode)
+
+        game_result = "won" if env.win_check() == 2 else "lost"
+        print(f"game {env.episode}: bot " + game_result + f", total_reward={total_reward}, step={total_steps - _step}")

         if env.quit_signal():
             break

@@ -89,12 +98,5 @@ def save_model():


 """
-Traceback (most recent call last):
-  File "D:/MyFiles/LearningBot/main.py", line 88, in <module>
-    main()
-  File "D:/MyFiles/LearningBot/main.py", line 83, in main
-    if env.quit_signal():
-AttributeError: 'OffSiteEnv' object has no attribute 'quit_signal'
-
-Process finished with exit code 1
+Check whether the mask is taking effect; right now it doesn't seem to do anything
 """

networks.py (+22 -15)

@@ -1,4 +1,5 @@
 from torch import nn
+from const import *


 def orthogonal_init(layer, gain=1.0):

@@ -28,23 +29,26 @@ def __init__(self, size):

         # fully connected layers and softmax
         self.dense_in = 48 * size ** 2
-        self.dense_out = 4 * ((size - 2) ** 2 + 3 * (size - 2) + 2) * 2
+        self.dense_out = 4 * ((size - 2) ** 2 + 3 * (size - 2) + 2) * 2 + 1
         self.dense1 = nn.Linear(in_features=self.dense_in, out_features=self.dense_out)
         self.dense2 = nn.Linear(in_features=self.dense_out, out_features=self.dense_out)
-        self.softmax = nn.Softmax(dim=0)
+        self.softmax = nn.Softmax(dim=1)

         # activation function
         self.activ_func = nn.Tanh()

         # orthogonal initialization
-        orthogonal_init(self.conv1)
-        orthogonal_init(self.conv2)
-        orthogonal_init(self.conv3)
-        orthogonal_init(self.conv4)
-        orthogonal_init(self.dense1)
-        orthogonal_init(self.dense2, gain=0.01)
+        # orthogonal_init(self.conv1)
+        # orthogonal_init(self.conv2)
+        # orthogonal_init(self.conv3)
+        # orthogonal_init(self.conv4)
+        # orthogonal_init(self.dense1)
+        # orthogonal_init(self.dense2, gain=0.01)

     def forward(self, x):
+        # build the action mask table first
+        mask = at.mask(x, x.shape[2])
+
         x = self.conv1(x)
         x = self.batch_norm1(x)
         x = self.activ_func(x)

@@ -67,6 +71,9 @@ def forward(self, x):
         x = self.activ_func(x)

         x = self.dense2(x)
+
+        # apply the mask just before softmax
+        x *= mask
         x = self.softmax(x)
         return x
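A caution on x *= mask: the mask entries are 1 or -999999999, so multiplication flips any negative logit of an invalid action into a huge positive value, and softmax then concentrates probability on exactly the actions that should be suppressed. That is a plausible cause of the "mask doesn't seem to do anything" note left in main.py. The conventional pattern is an additive 0/-inf mask applied to the logits before softmax; a minimal self-contained sketch (it assumes the const.py helper would emit 0 instead of 1 for legal actions):

import torch

logits = torch.randn(1, 5)  # toy batch of action logits
# 0 keeps an action, -inf removes it
mask = torch.tensor([0.0, 0.0, float("-inf"), 0.0, float("-inf")])
probs = torch.softmax(logits + mask, dim=1)
print(probs)  # entries 2 and 4 are exactly 0; the rest renormalize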

@@ -87,20 +94,20 @@ def __init__(self, size):

         # fully connected layers
         self.dense_in = 48 * size ** 2
-        self.dense_out = 4 * ((size - 2) ** 2 + 3 * (size - 2) + 2) * 2
+        self.dense_out = 4 * ((size - 2) ** 2 + 3 * (size - 2) + 2) * 2 + 1
         self.dense1 = nn.Linear(in_features=self.dense_in, out_features=self.dense_out)
         self.dense2 = nn.Linear(in_features=self.dense_out, out_features=1)

         # activation function
         self.activ_func = nn.Tanh()

         # orthogonal initialization
-        orthogonal_init(self.conv1)
-        orthogonal_init(self.conv2)
-        orthogonal_init(self.conv3)
-        orthogonal_init(self.conv4)
-        orthogonal_init(self.dense1)
-        orthogonal_init(self.dense2, gain=0.01)
+        # orthogonal_init(self.conv1)
+        # orthogonal_init(self.conv2)
+        # orthogonal_init(self.conv3)
+        # orthogonal_init(self.conv4)
+        # orthogonal_init(self.dense1)
+        # orthogonal_init(self.dense2, gain=0.01)

     def forward(self, x):
         x = self.conv1(x)