Skip to content

Commit 73bbcaa

Browse files
committed
test(nn): 3层神经网络添加nesterov加速
1 parent 103a5a4 commit 73bbcaa

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

nn/nets.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ class ThreeLayerNet(Net):
8282
实现3层神经网络
8383
"""
8484

85-
def __init__(self, num_in, num_h_one, num_h_two, num_out, momentum=0, p_h=1.0):
85+
def __init__(self, num_in, num_h_one, num_h_two, num_out, momentum=0, nesterov=False, p_h=1.0, ):
8686
super(ThreeLayerNet, self).__init__()
87-
self.fc1 = FC(num_in, num_h_one, momentum=momentum)
87+
self.fc1 = FC(num_in, num_h_one, momentum=momentum, nesterov=nesterov)
8888
self.relu1 = ReLU()
89-
self.fc2 = FC(num_h_one, num_h_two, momentum=momentum)
89+
self.fc2 = FC(num_h_one, num_h_two, momentum=momentum, nesterov=nesterov)
9090
self.relu2 = ReLU()
91-
self.fc3 = FC(num_h_two, num_out, momentum=momentum)
91+
self.fc3 = FC(num_h_two, num_out, momentum=momentum, nesterov=nesterov)
9292
self.p_h = p_h
9393

9494
def __call__(self, inputs):

0 commit comments

Comments
 (0)