class PPOAgent(object):

    def __init__(self, args: dict):
-        self.batch_size = args["batch_size"]  # batch size
-        self.lr_a = args["lr_a"]  # Learning rate of actor
-        self.lr_c = args["lr_c"]  # Learning rate of critic
-        self.gamma = args["gamma"]  # Discount factor
-        self.lamda = args["lambda"]  # GAE parameter
-        self.epsilon = args["epsilon"]  # PPO clip parameter
-        self.k_epochs = args["k_epochs"]  # PPO parameter
+        self.batch_size = args["batch_size"]  # batch size
+        self.lr_a = args["lr_a"]  # learning rate of the actor (policy network)
+        self.lr_c = args["lr_c"]  # learning rate of the critic (value network)
+        self.gamma = args["gamma"]  # discount factor
+        self.lamda = args["lambda"]  # GAE lambda
+        self.epsilon = args["epsilon"]  # PPO clip parameter
+        self.k_epochs = args["k_epochs"]  # number of PPO update epochs

        self.entropy_coef = args["entropy_coef"]
-        self.device = args["device"]  # device
+        self.device = args["device"]  # device to run on

-        # networks
+        # neural networks
        self.pai_set = {
            20: get_model("actor", "./model/non_maze.pth", 20),
            19: get_model("actor", "./model/maze.pth", 19),
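Note: a minimal sketch of the args dictionary this constructor expects. The keys come from the code above; the concrete values and the construction call are hypothetical illustrations, not settings taken from this repository.

    # Hypothetical example values; only the keys come from __init__ above.
    args = {
        "batch_size": 2048,
        "lr_a": 3e-4,          # actor learning rate
        "lr_c": 1e-3,          # critic learning rate
        "gamma": 0.99,         # discount factor
        "lambda": 0.95,        # GAE lambda
        "epsilon": 0.2,        # PPO clip range
        "k_epochs": 10,        # update epochs per batch
        "entropy_coef": 0.01,  # entropy bonus weight
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }
    agent = PPOAgent(args)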
@@ -46,10 +46,10 @@ def learn(self, rep, step_t):
        """
        s, a, a_log_prob, r, s_, dw, done = rep.to_tensor()

-        # calculate GAE advantage
+        # compute the advantage with GAE
        adv = []
        gae = 0
-        with torch.no_grad():  # adv and v_target have no gradient
+        with torch.no_grad():  # no gradients needed here
            vs = self.v(s)
            vs_ = self.v(s_)
            deltas = r + self.gamma * (1.0 - dw) * vs_ - vs
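Note: the lines this hunk skips presumably turn the TD errors above into GAE advantages. A minimal sketch of the standard GAE recursion, continuing from the adv list and gae accumulator initialized above; this is an assumption, not the repository's exact code.

    # Hypothetical sketch of the elided GAE recursion (runs inside the no_grad block).
    for delta, d in zip(reversed(deltas.flatten().tolist()),
                        reversed(done.flatten().tolist())):
        gae = delta + self.gamma * self.lamda * gae * (1.0 - d)  # reset at episode ends
        adv.insert(0, gae)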
@@ -59,44 +59,46 @@ def learn(self, rep, step_t):
            adv = torch.tensor(adv, dtype=torch.float).view(-1, 1)
            v_target = adv + vs

-        # advantage normalization
+        # normalize the advantages
        adv = ((adv - adv.mean()) / (adv.std() + 1e-5))

-        # Optimize policy for K epochs:
+        # update the parameters for k epochs
        for _ in range(self.k_epochs):
            for index in BatchSampler(SubsetRandomSampler(range(self.batch_size)), self.batch_size, False):
                dist_now = Categorical(self.pai(s[index]))
-                dist_entropy = dist_now.entropy().view(-1, 1)  # shape(batch_size X 1)
-                a_log_prob_now = dist_now.log_prob(a[index].squeeze()).view(-1, 1)  # shape(batch_size X 1)
+                dist_entropy = dist_now.entropy().view(-1, 1)  # shape (batch_size x 1)
+                a_log_prob_now = dist_now.log_prob(a[index].squeeze()).view(-1, 1)  # shape (batch_size x 1)

                # https://www.luogu.com.cn/paste/9vwi6ls0
-                ratios = torch.exp(a_log_prob_now - a_log_prob[index])  # shape(batch_size X 1)
+                # probability ratio for the clipped surrogate objective
+                ratios = torch.exp(a_log_prob_now - a_log_prob[index])  # shape (batch_size x 1)
                surr1 = ratios * adv[index]
                surr2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * adv[index]
                actor_loss = -torch.min(surr1,
-                                        surr2) - self.entropy_coef * dist_entropy  # shape(batch_size X 1)
-                # Update actor
+                                        surr2) - self.entropy_coef * dist_entropy  # shape (batch_size x 1)
+                # update the policy network
                self.optimizer_actor.zero_grad()
                actor_loss.mean().backward()
-                # Gradient clip
+                # gradient clipping
                torch.nn.utils.clip_grad_norm_(self.pai.parameters(), 0.5)
                self.optimizer_actor.step()

+                # value network update
                v_s = self.v(s[index])
                critic_loss = self.mse_loss_fn(v_target[index], v_s)
-                # Update critic
+                # update the value network
                self.optimizer_critic.zero_grad()
                critic_loss.backward()
-                # Gradient clip
+                # gradient clipping
                torch.nn.utils.clip_grad_norm_(self.v.parameters(), 0.5)
                self.optimizer_critic.step()

        self.lr_decay(step_t)

    def lr_decay(self, total_steps):
        """
-        learning rate decay
-        :param total_steps:
+        learning rate decay
+        :param total_steps: number of training steps taken so far
        :return:
        """
        decay_rate = 0.1
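Note: the rest of lr_decay falls outside this hunk, so only decay_rate = 0.1 is visible. A hypothetical sketch of one common way such a decay is applied, assuming a self.max_train_steps attribute and a linear schedule down to decay_rate of the initial learning rates; the real implementation may differ.

    # Hypothetical sketch; self.max_train_steps and the linear schedule are assumptions.
    frac = min(total_steps / self.max_train_steps, 1.0)
    lr_a_now = self.lr_a * (1.0 - (1.0 - decay_rate) * frac)  # lr_a -> 0.1 * lr_a
    lr_c_now = self.lr_c * (1.0 - (1.0 - decay_rate) * frac)  # lr_c -> 0.1 * lr_c
    for group in self.optimizer_actor.param_groups:
        group["lr"] = lr_a_now
    for group in self.optimizer_critic.param_groups:
        group["lr"] = lr_c_now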
@@ -110,7 +112,7 @@ def lr_decay(self, total_steps):

    def predict(self, observation):
        """
-        sample an action from policy network
+        sample an action from the policy network
        :param observation: s_t
        :return: 2 tensors: action, ln(p(a_t|s_t))
        """
@@ -122,7 +124,7 @@ def predict(self, observation):

    def change_network(self, map_size):
        """
-        change policy and value network for a new game
+        switch the policy and value networks when the game mode changes
        :param map_size:
        :return:
        """
@@ -131,21 +133,21 @@ def change_network(self, map_size):

    def warm_up(self):
        """
-        warm up neural networks
+        warm up the neural networks, since their first forward pass is noticeably slow
        :return:
        """
        t = torch.zeros([1, 12, 20, 20]).to(self.device)
        self.pai(t)
        self.v(t)

    def save(self):
-        # save policy networks
+        # save the policy networks
        torch.save(self.pai_set[20], "./model/non_maze.pth")
        torch.save(self.pai_set[10], "./model/non_maze1v1.pth")
        torch.save(self.pai_set[19], "./model/maze.pth")
        torch.save(self.pai_set[9], "./model/maze1v1.pth")

-        # save value networks
+        # save the value networks
        torch.save(self.v_set[20], "./model/non_maze_critic.pth")
        torch.save(self.v_set[10], "./model/non_maze1v1_critic.pth")
        torch.save(self.v_set[19], "./model/maze_critic.pth")