Skip to content

Commit 5243103

Browse files
committed
perf(draw): 保存绘制图片
1 parent f16c6fa commit 5243103

File tree

8 files changed

+12
-9
lines changed

8 files changed

+12
-9
lines changed

examples/2_nn_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,5 @@
3535
plt = Draw()
3636
plt(solver.loss_history)
3737
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
38-
title='准确率', xlabel='迭代/次', ylabel='准确率')
38+
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
3939
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/3_nn_cifar10.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@
2727
plt = Draw()
2828
plt(solver.loss_history)
2929
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
30-
title='准确率', xlabel='迭代/次', ylabel='准确率')
30+
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
3131
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/3_nn_iris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@
3939
plt = Draw()
4040
plt(solver.loss_history)
4141
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
42-
title='准确率', xlabel='迭代/次', ylabel='准确率')
42+
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
4343
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/3_nn_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,5 @@
3535
plt = Draw()
3636
plt(solver.loss_history)
3737
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
38-
title='准确率', xlabel='迭代/次', ylabel='准确率')
38+
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
3939
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/3_nn_orl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@
3737
plt = Draw()
3838
plt(solver.loss_history)
3939
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
40-
title='准确率', xlabel='迭代/次', ylabel='准确率')
40+
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
4141
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/lenet5_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@
3737
plt = Draw()
3838
plt(solver.loss_history)
3939
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
40-
title='准确率', xlabel='迭代/次', ylabel='准确率')
40+
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
4141
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))

examples/nin_cifar10.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def nin_train():
7878

7979
draw = vision.Draw()
8080
draw(loss_list, xlabel='迭代/20次')
81-
draw.multi_plot((train_list, test_list), ('训练集', '测试集'), title='精度图', xlabel='迭代/20次', ylabel='精度值')
81+
draw.multi_plot((train_list, test_list), ('训练集', '测试集'),
82+
title='精度图', xlabel='迭代/20次', ylabel='精度值', save_path='acc.png')
8283

8384

8485
if __name__ == '__main__':

pynet/vision/draw.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ class Draw(object):
1212
def __call__(self, values, title='损失图', xlabel='迭代/次', ylabel='损失值'):
1313
self.forward(values, title='损失图', xlabel='迭代/次', ylabel='损失值')
1414

15-
def forward(self, values, title='损失图', xlabel='迭代/次', ylabel='损失值'):
15+
def forward(self, values, title='损失图', xlabel='迭代/次', ylabel='损失值', save_path='./loss.png'):
1616
assert isinstance(values, list)
1717
plt.title(title)
1818
plt.ylabel(ylabel)
1919
plt.xlabel(xlabel)
2020
plt.plot(values)
21+
plt.savefig(save_path)
2122
plt.show()
2223

23-
def multi_plot(self, values_list, labels_list, title='损失图', xlabel='迭代/次', ylabel='损失值'):
24+
def multi_plot(self, values_list, labels_list, title='损失图', xlabel='迭代/次', ylabel='损失值', save_path='./loss.png'):
2425
assert isinstance(values_list, tuple)
2526
assert isinstance(labels_list, tuple)
2627
assert len(values_list) == len(labels_list)
@@ -31,4 +32,5 @@ def multi_plot(self, values_list, labels_list, title='损失图', xlabel='迭代
3132
for i in range(len(values_list)):
3233
plt.plot(values_list[i], label=labels_list[i])
3334
plt.legend()
35+
plt.savefig(save_path)
3436
plt.show()

0 commit comments

Comments
 (0)