Commit 8167d24

1 parent e8a3d77 commit 8167d24

6 files changed, +62 -49 lines changed


agent.py

+6-2
@@ -1,3 +1,4 @@
+from torch.nn import functional as F
 from torch.distributions import Categorical
 from torch.utils.data import BatchSampler, SubsetRandomSampler
 from utils import *

@@ -71,7 +72,8 @@ def learn(self, rep, step_t):
         # update the parameters for k epochs
         for _ in range(self.k_epochs):
             for index in BatchSampler(SubsetRandomSampler(range(self.batch_size)), self.batch_size, False):
-                dist_now = Categorical(self.pai(s[index]))
+                mask = at.mask(s[index], s[index].shape[2])
+                dist_now = Categorical(mask * self.pai.softmax(self.pai(s[index])))
                 dist_entropy = dist_now.entropy().view(-1, 1)  # shape(batch_size x 1)
                 a_log_prob_now = dist_now.log_prob(a[index].squeeze()).view(-1, 1)  # shape(batch_size x 1)

@@ -123,7 +125,9 @@ def predict(self, observation):
         :return: 2 tensors: action, ln(p(a_t|s_t))
         """
         with torch.no_grad():
-            action_p = Categorical(self.pai(observation))
+            mask = at.mask(observation, observation.shape[2])
+            act_ = self.pai(observation) * mask
+            action_p = Categorical(self.pai.softmax(act_))
             action = action_p.sample()
             a_log_prob = action_p.log_prob(action)
             return action, a_log_prob
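
Note on the masking change above: the commit multiplies the policy output by a mask vector whose invalid entries are -inf and then applies a softmax. A common alternative, sketched below under its own assumptions (logits, valid and masked_categorical are illustrative names, not code from this repository), is to fill the invalid logits with -inf and let Categorical(logits=...) perform the softmax, which drives the probability of invalid actions to zero.

import torch
from torch.distributions import Categorical

# Minimal sketch of action masking before the softmax (illustrative, not this repo's code).
# `logits` has shape [batch, n_actions]; `valid` is a bool tensor of the same shape
# marking which actions are currently legal.
def masked_categorical(logits: torch.Tensor, valid: torch.Tensor) -> Categorical:
    masked_logits = logits.masked_fill(~valid, float("-inf"))  # invalid actions -> -inf
    return Categorical(logits=masked_logits)                   # softmax happens inside

logits = torch.randn(2, 5)
valid = torch.tensor([[True, True, False, True, False],
                      [True, False, True, True, True]])
dist = masked_categorical(logits, valid)
action = dist.sample()             # only legal actions can be sampled
log_prob = dist.log_prob(action)   # log-probability under the masked distribution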

const.py

+24-9
@@ -61,7 +61,7 @@ class Style(object):
 
 dx = [0, -1, 0, 1]
 dy = [-1, 0, 1, 0]
-inf = 999999999
+inf = 999999999.0
 
 
 class ActionTranslator(object):

@@ -95,6 +95,13 @@ def _generate(self, size):
                         continue
                     index[i][j][k][0] = len(action)
                     action.append(torch.Tensor([[i, j, tgx, tgy, 0]]))
+        for i in range(1, size + 1):
+            for j in range(1, size + 1):
+                for k in range(4):
+                    tgx = i + dx[k]
+                    tgy = j + dy[k]
+                    if tgx < 1 or tgx > size or tgy < 1 or tgy > size:
+                        continue
                     index[i][j][k][1] = len(action)
                     action.append(torch.Tensor([[i, j, tgx, tgy, 1]]))
 

@@ -113,18 +120,26 @@ def a_to_i(self, size: int, action: torch.Tensor) -> int:
     def mask(self, obs, map_size):
         o = obs[0]
         mask_vec = torch.zeros([len(self.__actions[map_size])], dtype=torch.long)
-        for _a in range(len(self.__actions[map_size])):
-            act = self.__actions[map_size][_a][0]
-            act -= 1
-            act[4] += 1
-
-            if int(o[2][int(act[1]) - 1][int(act[0]) - 1]) != 0:
+        mask_vec[0] = 1.0
+        for _a in range(1, len(self.__actions[map_size])):
+            act = self.__actions[map_size][_a][0].long().tolist()
+            if int(o[10][act[1] - 1][act[0] - 1]) != 0:
                 # not one of our own tiles
                 mask_vec[_a] = -inf
             else:
-                mask_vec[_a] = 1
+                mask_vec[_a] = 1.0
         return mask_vec.to(device)
 
 
 at = ActionTranslator()
-device = torch.device("cuda")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+
+def debug_output_mask(_mask):
+    print("[", end='')
+    for _ in _mask:
+        print(f"{_},", end=' ')
+    print("]")

main.py

+11-9
@@ -46,25 +46,28 @@ def save_model():
         agent.change_network(env.map_size)
 
         # initialize the normalization helper classes
-        state_norm = Normalization(shape=args["state_dim"])  # Trick 2: state normalization
+        # state_norm = Normalization(shape=args["state_dim"])  # Trick 2: state normalization
         reward_scaling = RewardScaling(shape=1, gamma=args["gamma"])
 
         replay_buffer = ReplayBuffer(args)
 
-        s = state_norm(s).to(device)
+        # s = state_norm(s).to(device)
+        s = s.to(device)
         reward_scaling.reset()
 
         done = False
         total_reward = 0
         _step = total_steps
+        render_mode = "machine"
        while not done:
            a, a_log_prob = agent.predict(s)
            s_, r, done, _ = env.step(a)
            total_reward += r
 
-            env.render("human")
+            env.render(render_mode)
 
-            s_ = state_norm(s_).to(device)
+            # s_ = state_norm(s_).to(device)
+            s_ = s_.to(device)
            r = reward_scaling(r)
 
            replay_buffer.store(s, a, a_log_prob, r, s_, done)

@@ -82,6 +85,10 @@ def save_model():
            _t2 = threading.Thread(target=save_model)
            _t2.start()
 
+            # manual special-case handling
+            if r > 0:
+                render_mode = "human"
+
         # plot the reward curve to show learning progress
         if env.episode % 10 == 0:
             writer.add_scalar(f"offline_train_{env.mode}", total_reward, env.episode)

@@ -95,8 +102,3 @@ def save_model():
 
 if __name__ == '__main__':
     main()
-
-
-"""
-Check whether the mask takes effect; right now it does not seem to be working.
-"""

networks.py

+5-8
@@ -1,5 +1,6 @@
+import torch
 from torch import nn
-from const import *
+from torch.nn import functional as F
 
 
 def orthogonal_init(layer, gain=1.0):

@@ -46,9 +47,6 @@ def __init__(self, size):
         # orthogonal_init(self.dense2, gain=0.01)
 
     def forward(self, x):
-        # build the mask table first
-        mask = at.mask(x, x.shape[2])
-
         x = self.conv1(x)
         x = self.batch_norm1(x)
         x = self.activ_func(x)

@@ -71,11 +69,10 @@ def forward(self, x):
         x = self.activ_func(x)
 
         x = self.dense2(x)
+        return self.my_PReLU(x)
 
-        # apply the mask right before the softmax
-        x *= mask
-        x = self.softmax(x)
-        return x
+    def my_PReLU(self, x):
+        return torch.max(x, torch.FloatTensor([0.0]).cuda()) - 0.05 * torch.min(x, torch.FloatTensor([0.0]).cuda())
 
 
 class Critic(nn.Module):
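
Note on my_PReLU above: because of the minus sign on the negative branch it computes max(x, 0) - 0.05 * min(x, 0), so negative inputs come out as 0.05 * |x| and the output is never negative, whereas a standard leaky ReLU (max(x, 0) + slope * min(x, 0)) keeps the sign. The hard-coded .cuda() also assumes a GPU, unlike the CPU fallback this same commit adds to const.py. A small illustrative comparison (not repository code):

import torch
from torch.nn import functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
zero = torch.zeros_like(x)  # device-agnostic zero instead of a hard-coded .cuda() tensor

# The commit's formula: negative inputs are reflected to 0.05 * |x|.
my_prelu = torch.max(x, zero) - 0.05 * torch.min(x, zero)  # tensor([0.1000, 0.0250, 0.0000, 1.5000])

# The built-in leaky ReLU keeps the sign of the negative branch.
leaky = F.leaky_relu(x, negative_slope=0.05)               # tensor([-0.1000, -0.0250, 0.0000, 1.5000])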

offsite_env.py

+13-16
@@ -33,10 +33,6 @@ def __init__(self, mode="non_maze"):
         self.internal_bots = {}
         self.internal_bots_num = 0
         self.internal_bots_color = []
-        for i in range(1, 9):
-            if i != self.learningbot_color:
-                self.internal_bots_color.append(i)
-                self.internal_bots[i] = ibot.Game(i, self.get_view_of, self.bot_action_upd)
         self.actions_now = [[], [], [], [], [], [], [], []]
 
     def reset(self):

@@ -56,19 +52,26 @@ def reset(self):
         self.player_num = random.randint(3, 8)
         self.episode += 1
         self.gen_map(self.player_num)
+        self.round = 0
 
         # initialize shown
         self.shown = torch.zeros([9, self.map_size, self.map_size])
 
         # initialize the internal bots
         self.internal_bots_num = self.player_num - 1
+        for i in range(1, 9):
+            if i != self.learningbot_color:
+                self.internal_bots_color.append(i)
+                self.internal_bots[i] = ibot.Game(i, self.get_view_of, self.bot_action_upd)
 
         # build the observation
+        self.obs_history.queue.clear()
         self.obs_history.put(torch.zeros([4, self.map_size, self.map_size]))
         self.obs_history.put(torch.zeros([4, self.map_size, self.map_size]))
         self.obs_history.put(copy.copy(self.get_view_of(self.learningbot_color)))
 
         # push an empty action first; action_history should only store lists, otherwise odd errors appear
+        self.action_history.queue.clear()
         self.action_history.put([-1, -1, -1, -1, -1])
         return self.gen_observation()
 

@@ -111,12 +114,9 @@ def step(self, action: torch.Tensor):
         # if the action is empty
         if last_move[0] < 0 or self.round == 1:
             return obs, reward, False, {}
-        # large penalty for an invalid move
-        if int(last_obs[2][last_move[1] - 1][last_move[0] - 1]) != self._get_colormark(self.learningbot_color):
-            reward -= 100
         # small penalty for running into a mountain
         if last_obs[1][last_move[3] - 1][last_move[2] - 1] == BlockType.mountain:
-            reward -= 2
+            reward -= 0.3
         # penalty for attacking a city
         if self.map[1][last_move[3] - 1][last_move[2] - 1] == BlockType.city:
             if self.map[2][last_move[3] - 1][last_move[2] - 1] != self.learningbot_color:

@@ -190,19 +190,16 @@ def execute_actions(self, action: torch.Tensor):
                 self.combine((cur_action[1], cur_action[0]), (cur_action[3], cur_action[2]), mov_troop)
 
         # handle the LearningBot action
-        act = at.i_to_a(self.map_size, int(action))[0].long()
-        act -= 1
-        act[4] += 1
-        # is_available = "available" if self.map[2][act[1]][act[0]] == self.learningbot_color else "unavailable"
-        # print(f"<{is_available}>: {act.tolist()}")
+        act = at.i_to_a(self.map_size, int(action))[0].long().tolist()
+        print(act)
         # check whether the action is legal; act may contain -1, which means an empty turn
-        if act[0] >= 0 and self.map[2][act[1]][act[0]] == self.learningbot_color:
-            f_amount = int(self.map[0][act[1]][act[0]])
+        if act[0] - 1 >= 0 and self.map[2][act[1] - 1][act[0] - 1] == self.learningbot_color:
+            f_amount = int(self.map[0][act[1] - 1][act[0] - 1])
             if act[4] == 1:
                 mov_troop = math.ceil((f_amount + 0.5) / 2) - 1
             else:
                 mov_troop = f_amount - 1
-            self.combine((act[1], act[0]), (act[3], act[2]), mov_troop)
+            self.combine((act[1] - 1, act[0] - 1), (act[3] - 1, act[2] - 1), mov_troop)
 
     def combine(self, b1: tuple, b2: tuple, cnt):
         """

onsite_env.py

+3-5
@@ -98,18 +98,15 @@ def step(self, action: torch.Tensor):
         # compute the reward for the previous step
         _dirx = [0, -1, 0, 1, 1, -1, 1, -1]
         _diry = [-1, 0, 1, 0, 1, -1, -1, 1]
-        last_move = self.action_history.queue[-1]
+        last_move = self.action_history.queue[-1].long().tolist()
         last_map = self.map_history.queue[-1]
         # save the action
         if self.action_history.qsize() == 3:
             self.action_history.get()
-        self.action_history.put(copy.copy(action[0].long()))
+        self.action_history.put(copy.copy(at.i_to_a(self.map_size, int(action[0].long()))[0]))
         # if the action is empty
         if last_move[0] < 0:
             return self.observation, reward, False, {}
-        # large penalty for an invalid move
-        if last_map[2][last_move[1] - 1][last_move[0] - 1] != self._get_colormark(self.self_color):
-            reward -= 100
         # small penalty for running into a mountain
         if self.map[1][last_move[3] - 1][last_move[2] - 1] == BlockType.mountain:
             reward -= 10

@@ -276,6 +273,7 @@ def win_check(self) -> int:
         """
         try:
             t = self.driver.find_element(By.ID, "swal2-content")
+            self.driver.find_element(By.CSS_SELECTOR, "div.swal2-actions > button.swal2-confirm.swal2-styled")
             if t.text.strip() == settings.bot_name + "赢了":
                 return 2
         except NoSuchElementException:
