Skip to content

Commit a5d97ea

Browse files
committed
perf(example): 使用pynet内置的绘制类Draw
1 parent 04fff8e commit a5d97ea

File tree

6 files changed

+30
-18
lines changed

6 files changed

+30
-18
lines changed

examples/2_nn_mnist.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pynet.optim as optim
99
import pynet.nn as nn
1010
from pynet.vision.data import mnist
11-
import plt
11+
from pynet.vision import Draw
1212

1313
data_path = '/home/zj/data/decompress_mnist'
1414
# data_path = '/home/lab305/Documents/zj/data/mnist'
@@ -33,6 +33,8 @@
3333
solver = pynet.Solver(model, data, criterion, optimizer, batch_size=128, num_epochs=10)
3434
solver.train()
3535

36-
plt.draw_loss(solver.loss_history)
37-
plt.draw_acc((solver.train_acc_history, solver.val_acc_history), ('train', 'val'))
36+
plt = Draw()
37+
plt(solver.loss_history)
38+
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
39+
title='准确率', xlabel='迭代/次', ylabel='准确率')
3840
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/3_nn_cifar10.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pynet.optim as optim
1010
from pynet.vision.data import cifar
1111
import pynet.nn as nn
12-
import plt
12+
from pynet.vision import Draw
1313

1414
data_path = '/home/lab305/Documents/zj/data/cifar_10/cifar-10-batches-py'
1515

@@ -24,6 +24,8 @@
2424
reg=1e-3, print_every=1)
2525
solver.train()
2626

27-
plt.draw_loss(solver.loss_history)
28-
plt.draw_acc((solver.train_acc_history, solver.val_acc_history), ('train', 'val'))
27+
plt = Draw()
28+
plt(solver.loss_history)
29+
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
30+
title='准确率', xlabel='迭代/次', ylabel='准确率')
2931
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/3_nn_iris.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pynet.models as models
1111
import pynet.nn as nn
1212
from pynet.vision.data import iris
13-
import plt
13+
from pynet.vision import Draw
1414

1515
data_path = '/home/zj/data/iris-species/Iris.csv'
1616

@@ -36,6 +36,8 @@
3636
reg=1e-3, print_every=500)
3737
solver.train()
3838

39-
plt.draw_loss(solver.loss_history)
40-
plt.draw_acc((solver.train_acc_history, solver.val_acc_history), ('train', 'val'))
39+
plt = Draw()
40+
plt(solver.loss_history)
41+
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
42+
title='准确率', xlabel='迭代/次', ylabel='准确率')
4143
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/3_nn_mnist.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pynet.optim as optim
99
import pynet.nn as nn
1010
from pynet.vision.data import mnist
11-
import plt
11+
from pynet.vision import Draw
1212

1313
data_path = '/home/zj/data/decompress_mnist'
1414

@@ -32,6 +32,8 @@
3232
solver = pynet.Solver(model, data, criterion, optimizer, batch_size=256, num_epochs=10, print_every=1, reg=1e-3)
3333
solver.train()
3434

35-
plt.draw_loss(solver.loss_history)
36-
plt.draw_acc((solver.train_acc_history, solver.val_acc_history), ('train', 'val'))
35+
plt = Draw()
36+
plt(solver.loss_history)
37+
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
38+
title='准确率', xlabel='迭代/次', ylabel='准确率')
3739
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/3_nn_orl.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pynet.nn as nn
1010
import pynet.optim as optim
1111
from pynet.vision.data import orl
12-
import plt
12+
from pynet.vision import Draw
1313

1414
data_path = '/home/zj/data/att_faces_png'
1515

@@ -34,6 +34,8 @@
3434
reg=1e-3, print_every=1)
3535
solver.train()
3636

37-
plt.draw_loss(solver.loss_history)
38-
plt.draw_acc((solver.train_acc_history, solver.val_acc_history), ('train', 'val'))
37+
plt = Draw()
38+
plt(solver.loss_history)
39+
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
40+
title='准确率', xlabel='迭代/次', ylabel='准确率')
3941
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/lenet5_mnist.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pynet.optim as optim
99
import pynet.nn as nn
1010
from pynet.vision.data import mnist
11-
import plt
11+
from pynet.vision import Draw
1212

1313
# data_path = '/home/zj/data/decompress_mnist'
1414
data_path = '/home/lab305/Documents/zj/data/mnist'
@@ -35,6 +35,8 @@
3535
lr_scheduler=stepLR, batch_size=128, num_epochs=10, print_every=1)
3636
solver.train()
3737

38-
plt.draw_loss(solver.loss_history)
39-
plt.draw_acc((solver.train_acc_history, solver.val_acc_history), ('train', 'val'))
38+
plt = Draw()
39+
plt(solver.loss_history)
40+
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
41+
title='准确率', xlabel='迭代/次', ylabel='准确率')
4042
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

0 commit comments

Comments
 (0)