
Commit 103a5a4

refactor(nesterov): rework the Nesterov implementation in the convolution and fully connected layers
1 parent c8ca7e6 commit 103a5a4


nn/layers.py

Lines changed: 18 additions & 12 deletions
@@ -39,13 +39,13 @@ def __init__(self, in_c, filter_h, filter_w, filter_num, stride=1, padding=0, mo
         self.filter_num = filter_num
         self.stride = stride
         self.padding = padding
-        self.nesterov = nesterov

         self.W = \
             {'val': 0.01 * np.random.normal(loc=0, scale=1.0, size=(filter_h * filter_w * in_c, filter_num)),
              'grad': 0,
              'v': 0,
-             'momentum': momentum}
+             'momentum': momentum,
+             'nesterov': nesterov}
         self.b = {'val': 0.01 * np.random.normal(loc=0, scale=1.0, size=(1, filter_num)), 'grad': 0}
         self.a = None
         self.input_shape = None
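The constructor change moves the nesterov flag from a standalone instance attribute into the per-parameter dict self.W, next to the velocity 'v' and the momentum coefficient, so the full optimizer state for the weights lives in one place. A minimal sketch of the resulting dict layout (sizes and hyper-parameter values below are made up for illustration, not taken from the repository):

import numpy as np

# Sketch only: the shape of the per-parameter state dict that self.W now holds.
filter_h, filter_w, in_c, filter_num = 3, 3, 3, 8
W = {
    'val': 0.01 * np.random.normal(loc=0, scale=1.0,
                                   size=(filter_h * filter_w * in_c, filter_num)),
    'grad': 0,         # filled in during the backward pass
    'v': 0,            # momentum velocity, maintained by update()
    'momentum': 0.9,   # momentum coefficient
    'nesterov': True,  # lookahead switch, now stored with the weights
}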
@@ -84,18 +84,21 @@ def update(self, learning_rate=0, regularization_rate=0):
         v_prev = self.W['v']
         self.W['v'] = self.W['momentum'] * self.W['v'] - learning_rate * (
                 self.W['grad'] + regularization_rate * self.W['val'])
-        if self.nesterov:
+        if self.W['nesterov']:
             self.W['val'] += (1 + self.W['momentum']) * self.W['v'] - self.W['momentum'] * v_prev
         else:
             self.W['val'] += self.W['v']
         self.b['val'] -= learning_rate * (self.b['grad'])

     def get_params(self):
-        return {'W': self.W['val'], 'b': self.b['val']}
+        return {'W': self.W['val'], 'momentum': self.W['momentum'], 'nesterov': self.W['nesterov'], 'b': self.b['val']}

     def set_params(self, params):
-        self.W['val'] = params['W']
-        self.b['val'] = params['b']
+        self.W['val'] = params.get('W')
+        self.b['val'] = params.get('b')
+
+        self.W['momentum'] = params.get('momentum', 0.0)
+        self.W['nesterov'] = params.get('nesterov', False)


 class MaxPool(Layer):
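The update rule itself is unchanged; only the lookup changes from self.nesterov to self.W['nesterov']. The velocity first absorbs the L2-regularized gradient, and with Nesterov enabled the weights take the lookahead step (1 + momentum) * v - momentum * v_prev instead of plain v. A standalone sketch of the same rule (not part of the repository), with a small scalar check:

def momentum_step(val, grad, v, momentum, learning_rate,
                  regularization_rate=0.0, nesterov=False):
    # Mirrors the update() bodies in this diff: velocity update first,
    # then either the Nesterov lookahead step or the classical momentum step.
    v_prev = v
    v = momentum * v - learning_rate * (grad + regularization_rate * val)
    if nesterov:
        # lookahead form: the stored parameters track val + momentum * v
        val = val + (1 + momentum) * v - momentum * v_prev
    else:
        val = val + v
    return val, v

# one step with made-up scalar values
w, v = momentum_step(1.0, grad=0.5, v=0.0, momentum=0.9,
                     learning_rate=0.1, nesterov=True)
# v = 0.9 * 0 - 0.1 * 0.5             = -0.05
# w = 1.0 + 1.9 * (-0.05) - 0.9 * 0   =  0.905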
@@ -159,11 +162,11 @@ def __init__(self, num_in, num_out, momentum=0, nesterov=False):
         assert isinstance(num_in, int) and num_in > 0
         assert isinstance(num_out, int) and num_out > 0

-        self.nesterov = nesterov
         self.W = {'val': 0.01 * np.random.normal(loc=0, scale=1.0, size=(num_in, num_out)),
                   'grad': 0,
                   'v': 0,
-                  'momentum': momentum}
+                  'momentum': momentum,
+                  'nesterov': nesterov}
         self.b = {'val': 0.01 * np.random.normal(loc=0, scale=1.0, size=(1, num_out)), 'grad': 0}
         self.inputs = None
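The fully connected layer gets the same constructor treatment, and its signature shows the defaults momentum=0 and nesterov=False. With those defaults both branches of update() collapse to plain SGD, since the velocity is just -learning_rate * grad and (1 + 0) * v - 0 * v_prev == v. Reusing the momentum_step sketch from above:

# momentum = 0: both the plain and the Nesterov branch reduce to SGD
w, v = momentum_step(1.0, grad=0.5, v=0.0, momentum=0.0,
                     learning_rate=0.1, nesterov=False)
# v = -0.05, w = 0.95, identical to w - learning_rate * grad

w_nag, _ = momentum_step(1.0, grad=0.5, v=0.0, momentum=0.0,
                         learning_rate=0.1, nesterov=True)
# w_nag == 0.95 as well, because (1 + 0) * v - 0 * v_prev == v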

@@ -189,18 +192,21 @@ def update(self, learning_rate=0, regularization_rate=0):
         v_prev = self.W['v']
         self.W['v'] = self.W['momentum'] * self.W['v'] - learning_rate * (
                 self.W['grad'] + regularization_rate * self.W['val'])
-        if self.nesterov:
+        if self.W['nesterov']:
             self.W['val'] += (1 + self.W['momentum']) * self.W['v'] - self.W['momentum'] * v_prev
         else:
             self.W['val'] += self.W['v']
         self.b['val'] -= learning_rate * self.b['grad']

     def get_params(self):
-        return {'W': self.W['val'], 'b': self.b['val']}
+        return {'W': self.W['val'], 'momentum': self.W['momentum'], 'nesterov': self.W['nesterov'], 'b': self.b['val']}

     def set_params(self, params):
-        self.W['val'] = params['W']
-        self.b['val'] = params['b']
+        self.W['val'] = params.get('W')
+        self.b['val'] = params.get('b')
+
+        self.W['momentum'] = params.get('momentum', 0.0)
+        self.W['nesterov'] = params.get('nesterov', False)


 class ReLU(Layer):
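get_params() now reports momentum and nesterov alongside the weights, and set_params() restores them with dict.get defaults, so checkpoints written before this change (which carry only 'W' and 'b') still load and simply fall back to momentum=0.0 and nesterov=False. The velocity 'v' is not part of the dict, so a freshly constructed layer starts it again from 0 after loading. A sketch of a checkpoint round trip, where layer stands for any convolution or fully connected instance and the helper names are made up:

import pickle

def save_checkpoint(layer, path):
    # get_params() now includes 'momentum' and 'nesterov' next to 'W' and 'b'
    with open(path, 'wb') as f:
        pickle.dump(layer.get_params(), f)

def load_checkpoint(layer, path):
    with open(path, 'rb') as f:
        params = pickle.load(f)
    # older checkpoints without the new keys load too:
    # set_params() falls back to momentum=0.0 and nesterov=False
    layer.set_params(params)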
