File tree 2 files changed +23
-8
lines changed
2 files changed +23
-8
lines changed Original file line number Diff line number Diff line change 1
1
# -*- coding: UTF-8 -*-
2
2
import pickle
3
+ from sklearn .metrics import f1_score ,precision_score ,recall_score
3
4
from dnlp .utils .constant import TAG_BEGIN , TAG_INSIDE , TAG_END , TAG_SINGLE
4
5
5
6
@@ -75,13 +76,26 @@ def evaluate_cws(model, data_path: str):
75
76
characters = data ['characters' ]
76
77
labels_true = data ['labels' ]
77
78
c_count = 0
79
+
78
80
p_count = 0
81
+
79
82
r_count = 0
83
+
84
+ all_labels_true = []
85
+ all_labels_predict = []
80
86
for sentence , label in zip (characters , labels_true ):
81
- words , labels_predict = model .predict (sentence , return_labels = True )
87
+ words , labels_predict = model .predict_ll (sentence , return_labels = True )
88
+ #print("============")
89
+ #print(words)
90
+ all_labels_predict .extend (labels_predict )
91
+ all_labels_true .extend (label )
82
92
c , p , r = get_cws_statistics (label , labels_predict )
83
93
c_count += c
84
94
p_count += p
85
95
r_count += r
86
96
print (c_count / p_count )
87
97
print (c_count / r_count )
98
+ average = 'macro'
99
+ print (precision_score (all_labels_true ,all_labels_predict ,average = average ))
100
+ print (recall_score (all_labels_true ,all_labels_predict ,average = average ))
101
+
Original file line number Diff line number Diff line change 7
7
8
8
9
9
def train_cws ():
10
- data_path = '../dnlp/data/cws/pku_training .pickle'
10
+ data_path = '../dnlp/data/cws/msr_training .pickle'
11
11
config = DnnCrfConfig ()
12
- dnncrf = DnnCrf (config = config , data_path = data_path , nn = 'lstm ' )
13
- dnncrf .fit_ll ()
12
+ dnncrf = DnnCrf (config = config , data_path = data_path , nn = 'bilstm ' )
13
+ dnncrf .fit ()
14
14
15
15
16
16
def test_cws ():
17
17
sentence = '小明来自南京师范大学'
18
- model_path = '../dnlp/models/cws1.ckpt'
18
+ sentence = '中国人民决心继承邓小平同志的遗志,继续把建设有中国特色社会主义事业推向前进。'
19
+ model_path = '../dnlp/models/cws4.ckpt'
19
20
config = DnnCrfConfig ()
20
- dnncrf = DnnCrf (config = config , mode = 'predict' , model_path = model_path , nn = 'lstm ' )
21
- res , labels = dnncrf .predict (sentence , return_labels = True )
21
+ dnncrf = DnnCrf (config = config , mode = 'predict' , model_path = model_path , nn = 'bilstm ' )
22
+ res , labels = dnncrf .predict_ll (sentence , return_labels = True )
22
23
print (res )
23
- evaluate_cws (dnncrf , '../dnlp/data/cws/pku_test .pickle' )
24
+ evaluate_cws (dnncrf , '../dnlp/data/cws/msr_test .pickle' )
24
25
25
26
26
27
if __name__ == '__main__' :
You can’t perform that action at this time.
0 commit comments