
Commit 723af0f

committed
8.16 continue tuning
1 parent c46af30 commit 723af0f

File tree

7 files changed, +91 -38 lines changed

agent.py

+9-4
@@ -45,20 +45,25 @@ def learn(self, rep, step_t):
         :return:
         """
         s, a, a_log_prob, r, s_, done = rep.get_data()
+        # push everything onto the GPU
+        s = s.to(self.device)
+        a = a.to(self.device)
+        a_log_prob = a_log_prob.to(self.device)
+        r = r.to(self.device)
+        s_ = s_.to(self.device)
+        done = done.to(self.device)
 
         # compute the advantage function with GAE
         adv = []
         gae = 0
-        s = s.to(self.device)
-        s_ = s_.to(self.device)
         with torch.no_grad():  # no gradients needed
             vs = self.v(s)
             vs_ = self.v(s_)
             deltas = r + self.gamma * (1.0 - done) * vs_ - vs
-            for delta, d in zip(reversed(deltas.flatten().numpy()), reversed(done.flatten().numpy())):
+            for delta, d in zip(reversed(deltas.flatten()), reversed(done.flatten())):
                 gae = delta + self.gamma * self.lamda * gae * (1.0 - d)
                 adv.insert(0, gae)
-            adv = torch.tensor(adv, dtype=torch.float).view(-1, 1)
+            adv = torch.tensor(adv, dtype=torch.float).view(-1, 1).to(self.device)
             v_target = adv + vs
 
         # advantage normalization
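
Note on the hunk above: once the batch lives on the GPU, the old deltas.flatten().numpy() call would fail, since a CUDA tensor cannot be converted to a NumPy array without going through .cpu() first; iterating the tensors directly sidesteps that. Below is a minimal, self-contained sketch of the same GAE recursion, for reference only; the names gamma/lam/device mirror self.gamma, self.lamda and self.device from the diff.

import torch

def compute_gae(r, vs, vs_, done, gamma, lam, device):
    # TD residuals for the whole batch
    deltas = r + gamma * (1.0 - done) * vs_ - vs
    adv, gae = [], 0.0
    # walk the batch backwards, resetting the running sum at episode ends
    for delta, d in zip(reversed(deltas.flatten().tolist()),
                        reversed(done.flatten().tolist())):
        gae = delta + gamma * lam * gae * (1.0 - d)
        adv.insert(0, gae)
    adv = torch.tensor(adv, dtype=torch.float, device=device).view(-1, 1)
    v_target = adv + vs          # regression target for the critic
    return adv, v_target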

bot_div/game.py

+4-3
@@ -42,6 +42,7 @@ def get_map_from_env(self):
         game_map = self.get_tensor_map(self.bot_color)
         map_size = game_map.shape[1]
         game_map = game_map.long().tolist()
+        self.mp.resize(map_size)
         for i in range(1, map_size + 1):
             for j in range(1, map_size + 1):
                 if game_map[1][i - 1][j - 1] == BlockType.city:
@@ -242,7 +243,7 @@ def gather_army_to(self, x, y, method='rectangle'):  # gather armies towards (x, y)
         ans_top_right = (ans_top_left[0], ans_bottom_right[1])
         ans_bottom_left = (ans_bottom_right[0], ans_top_left[1])
         ans = [ans_top_left, ans_top_right, ans_bottom_left, ans_bottom_right]
-        print(ans, best_sum)
+        # print(ans, best_sum)
         tmp = []
         target_node = (x, y)
         for i in ans:
@@ -286,7 +287,7 @@ def gather_army_to(self, x, y, method='rectangle'):  # gather armies towards (x, y)
         min_y = min(ans_top_left[1], ans_top_right[1])
         max_y = max(ans_top_left[1], ans_top_right[1])
         while cx != end_node[0] or cy != end_node[1]:  # snake-order traversal
-            print(cx, cy)
+            # print(cx, cy)
             px = cx
             py = cy
             if cur_dir:
@@ -349,7 +350,7 @@ def flush_movements(self):  # update movements
             self.cur_y -= 1
         elif self.movements[0] == 'D':
             self.cur_y += 1
-        act = [x_old, y_old, self.cur_x, self.cur_y, is_half]
+        act = [y_old, x_old, self.cur_y, self.cur_x, is_half]
         self.send_action(self.bot_color, act)
         self.movements.pop(0)
         self.get_map_from_env()
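
The reordered act in flush_movements matches the convention noted in the environment docstrings (x,y and i,j are exactly reversed): the new line suggests send_action wants the coordinates in (y, x) order rather than the bot's internal (x, y). An illustrative helper, not part of the repository, that makes the reorder explicit:

def xy_move_to_rowcol(x1, y1, x2, y2, is_half):
    # screen-style (x, y) pairs -> the (row, col) order the consumer appears to expect
    return [y1, x1, y2, x2, is_half]

# e.g. xy_move_to_rowcol(3, 5, 4, 5, 0) -> [5, 3, 5, 4, 0]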

bot_div/map.py

+10
@@ -15,6 +15,16 @@ def __init__(self, amount=0, belong=0, type='land'):
         self.type = type  # land city general unknown mountain empty empty-city
         self.cost = 0
 
+    def __str__(self):
+        return '{' +\
+               f"""
+               'amount': {self.amount},
+               'belong': {self.belong},
+               'type': {self.type},
+               'cost': {self.cost}
+               """\
+               + '}'
+
 
 def dist_node(a, b):
     return dist(a[0], a[1], b[0], b[1])
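
A simpler alternative sketch for the new __str__ (illustrative only, not the committed version): building the text from a dict keeps the output on one line, avoids the embedded newlines of the triple-quoted f-string, and quotes the field values consistently.

def __str__(self):
    # same four fields as the diff, rendered via a plain dict
    return str({
        'amount': self.amount,
        'belong': self.belong,
        'type': self.type,
        'cost': self.cost,
    })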

main.py

+24-3
@@ -1,3 +1,4 @@
+import threading
 from torch.utils.tensorboard import SummaryWriter
 from agent import PPOAgent
 from onsite_env import OnSiteEnv
@@ -13,7 +14,7 @@ def main(offline_train=True):
 
     device = torch.device("cuda")
     args = {
-        "batch_size": 100,
+        "batch_size": 50,
         "state_dim": None,
         "action_dim": 5,
         "lr_a": 0.01,
@@ -30,6 +31,12 @@ def main(offline_train=True):
     agent = PPOAgent(args)
     agent.warm_up()
 
+    def update_model():
+        agent.learn(replay_buffer, total_steps)
+
+    def save_model():
+        agent.save()
+
     # plotter
     writer = SummaryWriter("offline_train_logs" if offline_train else "online_train_logs")
 
@@ -64,16 +71,30 @@ def main(offline_train=True):
 
         # update the parameters once the buffer reaches batch size
         if len(replay_buffer) == args["batch_size"]:
-            agent.learn(replay_buffer, total_steps)
+            _t = threading.Thread(target=update_model)
+            _t.start()
             replay_buffer.clear()
 
         # autosave the model; keep the LCM of batch_size and autosave_step as large as possible, since saving and updating at the same time is costly
         if total_steps % args["autosave_step"] == 0:
-            agent.save()
+            _t = threading.Thread(target=save_model)
+            _t.start()
 
         if env.quit_signal():
            break
 
 
 if __name__ == '__main__':
     main()
+
+
+"""
+Traceback (most recent call last):
+  File "D:/MyFiles/LearningBot/main.py", line 88, in <module>
+    main()
+  File "D:/MyFiles/LearningBot/main.py", line 83, in main
+    if env.quit_signal():
+AttributeError: 'OffSiteEnv' object has no attribute 'quit_signal'
+
+Process finished with exit code 1
+"""

offsite_env.py

+28-20
@@ -72,7 +72,7 @@ def reset(self):
     def step(self, action: torch.Tensor):
         """
         Execute one step
-        :param action: movement => tensor([[x1, y1, x2, y2, is_half]])
+        :param action: movement => tensor([[x1, y1, x2, y2, is_half]]); note that x,y and i,j are exactly reversed
         :return: observation (Tensor), reward (float), done (bool), info (dict)
         """
         # run
@@ -106,16 +106,19 @@ def step(self, action: torch.Tensor):
         if last_move[0] < 0:
             return obs, reward, False, {}
         # big penalty for an invalid move
-        if last_obs[2][last_move[0] - 1][last_move[1] - 1] != self._get_colormark(self.learningbot_color):
+        if last_obs[2][last_move[1] - 1][last_move[0] - 1] != self._get_colormark(self.learningbot_color):
             reward -= 100
+        # small penalty for bumping into a mountain
+        if last_obs[1][last_move[3] - 1][last_move[2] - 1] == BlockType.mountain:
+            reward -= 10
         # penalty for hitting a city
-        if self.map[1][last_move[2] - 1][last_move[3] - 1] == BlockType.city:
-            if self.map[2][last_move[2] - 1][last_move[3] - 1] != self.learningbot_color:
+        if self.map[1][last_move[3] - 1][last_move[2] - 1] == BlockType.city:
+            if self.map[2][last_move[3] - 1][last_move[2] - 1] != self.learningbot_color:
                 reward -= 10
         # bonus for exploring new territory (note: exploring, not capturing)
         for i in range(8):
-            t_x = last_move[2] - 1 + _dirx[i]
-            t_y = last_move[3] - 1 + _diry[i]
+            t_x = last_move[3] - 1 + _dirx[i]
+            t_y = last_move[2] - 1 + _diry[i]
             if t_x < 0 or t_x >= self.map_size or t_y < 0 or t_y >= self.map_size:
                 continue
             if self.map[3][t_x][t_y] - last_obs[3][t_x][t_y] == 1:
@@ -150,13 +153,17 @@ def render(self, mode="human"):
     def execute_actions(self, action: torch.Tensor):
         """
         Execute the actions
-        :param action: the LearningBot's action; the built-in bots' actions are stored on the class and need not be passed in
+        :param action: the LearningBot's action; the built-in bots' actions are stored on the class and need not be passed in; note that x,y and i,j are exactly reversed
         :return:
         """
         # handle the built-in bots' actions
         for i in range(self.internal_bots_num):
             cur_color = self.internal_bots_color[i]
-            self.internal_bots[cur_color].bot_move()
+            try:
+                # it sometimes throws inexplicable errors; not worth debugging, this bot is weak anyway and one lost turn doesn't matter; the real training mostly comes from playing against humans
+                self.internal_bots[cur_color].bot_move()
+            except Exception:
+                continue
             if not self.actions_now[cur_color]:
                 print(f"bot {cur_color} empty move")
                 continue
@@ -173,26 +180,26 @@ def execute_actions(self, action: torch.Tensor):
 
             print(f"internal bot color {cur_color}: {cur_action}")
             # check whether the action is legal
-            if cur_action[0] >= 0 and self.map[2][cur_action[0]][cur_action[1]] == cur_color:
-                f_amount = int(self.map[0][cur_action[0]][cur_action[1]])
+            if self.map[2][cur_action[1]][cur_action[0]] == cur_color:
+                f_amount = int(self.map[0][cur_action[1]][cur_action[0]])
                 if cur_action[4] == 1:
                     mov_troop = math.ceil((f_amount + 0.5) / 2) - 1
                 else:
                     mov_troop = f_amount - 1
-                self.combine((cur_action[0], cur_action[1]), (cur_action[2], cur_action[3]), mov_troop)
+                self.combine((cur_action[1], cur_action[0]), (cur_action[3], cur_action[2]), mov_troop)
 
         # handle the LearningBot's action
         act = self.at.i_to_a(self.map_size, int(action))[0].long()
         act -= 1
         act[4] += 1
-        # check whether the action is legal
-        if act[0] >= 0 and self.map[2][act[0]][act[1]] == self.learningbot_color:
-            f_amount = int(self.map[0][act[0]][act[1]])
+        # check whether the action is legal; act may contain -1, which means an empty turn
+        if act[0] >= 0 and self.map[2][act[1]][act[0]] == self.learningbot_color:
+            f_amount = int(self.map[0][act[1]][act[0]])
            if act[4] == 1:
                mov_troop = math.ceil((f_amount + 0.5) / 2) - 1
            else:
                mov_troop = f_amount - 1
-            self.combine((act[0], act[1]), (act[2], act[3]), mov_troop)
+            self.combine((act[1], act[0]), (act[3], act[2]), mov_troop)
 
     def combine(self, b1: tuple, b2: tuple, cnt):
         """
@@ -228,10 +235,9 @@ def combine(self, b1: tuple, b2: tuple, cnt):
             tcolor = t["color"]
             for i in range(self.map_size):
                 for j in range(self.map_size):
-                    if self.map[2][i][j] == tcolor:
+                    if int(self.map[2][i][j]) == tcolor:
                         self.map[2][i][j] = f["color"]
-                    if self.map[2][i][j] == BlockType.crown:
-                        self.map[2][i][j] = BlockType.city
+            t["type"] = BlockType.city
             t["color"] = f["color"]
             t["amount"] = -t["amount"]
 
@@ -284,7 +290,7 @@ def get_view_of(self, color):
                 # if this player cannot see this cell right now
                 if color != self.learningbot_color or int(self.shown[color][i][j]) == 0:
                     # if it is the LearningBot, keep its vision for it (●'◡'●)
-                    if self.map[1][i][j] == BlockType.city:
+                    if self.map[1][i][j] == BlockType.city or self.map[1][i][j] == BlockType.mountain:
                         map_filtered[1][i][j] = BlockType.obstacle
                 else:
                     map_filtered[0][i][j] = self.map[0][i][j]
@@ -341,10 +347,12 @@ def win_check(self):
         alive = []
         for i in range(self.map_size):
             for j in range(self.map_size):
+                if int(self.map[2][i][j]) == PlayerColor.grey:
+                    continue
                 if int(self.map[2][i][j]) not in alive:
                     alive.append(int(self.map[2][i][j]))
         if len(alive) > 1:
             return 0
         if alive[0] == self.learningbot_color:
             return 2
-        return 1
+        return 1
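
The win_check hunk now skips neutral (grey) cells when collecting surviving players, so an otherwise finished game is no longer kept "running" by unowned tiles. For reference, the same logic written compactly; this is an illustrative sketch using the names from the diff, not the repository's code:

def win_check(self):
    # collect every owner colour on the board, ignoring neutral (grey) cells
    alive = {int(c) for row in self.map[2] for c in row if int(c) != PlayerColor.grey}
    if len(alive) > 1:
        return 0                                  # more than one player left: still in progress
    return 2 if alive.pop() == self.learningbot_color else 1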

onsite_env.py

+14-7
@@ -71,7 +71,7 @@ def reset(self):
     def step(self, action: torch.Tensor):
         """
         Execute one step
-        :param action: movement => tensor([[x1, y1, x2, y2, is_half]])
+        :param action: movement => tensor([[x1, y1, x2, y2, is_half]]); note that x,y and i,j are exactly reversed
         :return: observation (Tensor), reward (float), done (bool), info (dict)
         """
         reward = 0
@@ -101,16 +101,19 @@ def step(self, action: torch.Tensor):
         last_move = self.action_history.queue[-1]
         last_map = self.map_history.queue[-1]
         # big penalty for an invalid move
-        if last_map[2][last_move[0] - 1][last_move[1] - 1] != self._get_colormark(self.self_color):
+        if last_map[2][last_move[1] - 1][last_move[0] - 1] != self._get_colormark(self.self_color):
             reward -= 100
+        # small penalty for bumping into a mountain
+        if self.map[1][last_move[3] - 1][last_move[2] - 1] == BlockType.mountain:
+            reward -= 10
         # penalty for hitting a city
-        if self.map[1][last_move[2] - 1][last_move[3] - 1] == BlockType.city:
-            if self.map[2][last_move[2] - 1][last_move[3] - 1] != self._get_colormark(self.self_color):
+        if self.map[1][last_move[3] - 1][last_move[2] - 1] == BlockType.city:
+            if self.map[2][last_move[3] - 1][last_move[2] - 1] != self._get_colormark(self.self_color):
                 reward -= 10
         # bonus for exploring new territory (note: exploring, not capturing)
         for i in range(8):
-            t_x = last_move[2] - 1 + _dirx[i]
-            t_y = last_move[3] - 1 + _diry[i]
+            t_x = last_move[3] - 1 + _dirx[i]
+            t_y = last_move[2] - 1 + _diry[i]
             if t_x < 0 or t_x >= self.map_size or t_y < 0 or t_y >= self.map_size:
                 continue
             if self.map[3][t_x][t_y] - last_map[3][t_x][t_y] == 1:
@@ -240,10 +243,14 @@ def update_map(self, _init_flag=False):
     def move(self, mov):
         """
         just as the name
-        :param mov: tensor([[x1, y1, x2, y2, is_half]])
+        :param mov: tensor([[x1, y1, x2, y2, is_half]]); note that x,y and i,j are exactly reversed
         :return:
         """
         move_info = mov[0].long()
+        # swap first: convert the x,y coordinates into i,j coordinates
+        move_info[0], move_info[1] = move_info[1], move_info[0]
+        move_info[2], move_info[3] = move_info[3], move_info[2]
+
         if self.selected[0] != move_info[0] - 1 or self.selected[1] != move_info[1] - 1:
             # if the source cell is not selected yet, click it first
             self.driver.find_element_by_id(f"td-{int((move_info[0] - 1) * self.map_size + move_info[1])}").click()
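
A caution on the new in-place swap: move_info is a tensor, and integer indexing into a tensor returns views of the same storage. The tuple assignment therefore first overwrites element 0 and then copies that already-overwritten value back into element 1, so both slots end up holding the original y value instead of being swapped. A sketch of one safe variant (illustrative only; it assumes the code below only reads scalar entries from move_info):

move_info = mov[0].long().tolist()    # plain Python ints: tuple assignment swaps correctly
# convert the (x, y) pairs into (i, j) = (row, col) order
move_info[0], move_info[1] = move_info[1], move_info[0]
move_info[2], move_info[3] = move_info[3], move_info[2]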

utils.py

+2-1
@@ -131,7 +131,8 @@ def print_tensor_map(game_map):
     size = game_map.shape[1]
     for i in range(size):
         for j in range(size):
-            if game_map[1][i][j] == BlockType.city or game_map[1][i][j] == BlockType.mountain:
+            if game_map[2][i][j] == PlayerColor.grey and \
+                    game_map[1][i][j] == BlockType.city or game_map[1][i][j] == BlockType.mountain:
                 bg = 40
             else:
                 bg = color_trans[int(game_map[2][i][j])]
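
A precedence note on the new condition: and binds tighter than or, so the expression parses as (grey and city) or mountain, which means every mountain gets the dark background regardless of its owner. That may well be intended, since mountains are never owned; if the grey check is meant to cover both block types, parentheses would be needed, as in this sketch:

if game_map[2][i][j] == PlayerColor.grey and \
        (game_map[1][i][j] == BlockType.city or game_map[1][i][j] == BlockType.mountain):
    bg = 40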
