Skip to content

Commit 71d9f38

Browse files
modify runner and evaluation code
1 parent 63fdb13 commit 71d9f38

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

python/dnlp/utils/evaluation.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: UTF-8 -*-
22
import pickle
3+
from sklearn.metrics import f1_score,precision_score,recall_score
34
from dnlp.utils.constant import TAG_BEGIN, TAG_INSIDE, TAG_END, TAG_SINGLE
45

56

@@ -75,13 +76,26 @@ def evaluate_cws(model, data_path: str):
7576
characters = data['characters']
7677
labels_true = data['labels']
7778
c_count = 0
79+
7880
p_count = 0
81+
7982
r_count = 0
83+
84+
all_labels_true = []
85+
all_labels_predict = []
8086
for sentence, label in zip(characters, labels_true):
81-
words, labels_predict = model.predict(sentence, return_labels=True)
87+
words, labels_predict = model.predict_ll(sentence, return_labels=True)
88+
#print("============")
89+
#print(words)
90+
all_labels_predict.extend(labels_predict)
91+
all_labels_true.extend(label)
8292
c, p, r = get_cws_statistics(label, labels_predict)
8393
c_count += c
8494
p_count += p
8595
r_count += r
8696
print(c_count / p_count)
8797
print(c_count / r_count)
98+
average = 'macro'
99+
print(precision_score(all_labels_true,all_labels_predict,average=average))
100+
print(recall_score(all_labels_true,all_labels_predict,average=average))
101+

python/scripts/cws_ner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,21 @@
77

88

99
def train_cws():
10-
data_path = '../dnlp/data/cws/pku_training.pickle'
10+
data_path = '../dnlp/data/cws/msr_training.pickle'
1111
config = DnnCrfConfig()
12-
dnncrf = DnnCrf(config=config, data_path=data_path, nn='lstm')
13-
dnncrf.fit_ll()
12+
dnncrf = DnnCrf(config=config, data_path=data_path, nn='bilstm')
13+
dnncrf.fit()
1414

1515

1616
def test_cws():
1717
sentence = '小明来自南京师范大学'
18-
model_path = '../dnlp/models/cws1.ckpt'
18+
sentence = '中国人民决心继承邓小平同志的遗志,继续把建设有中国特色社会主义事业推向前进。'
19+
model_path = '../dnlp/models/cws4.ckpt'
1920
config = DnnCrfConfig()
20-
dnncrf = DnnCrf(config=config, mode='predict', model_path=model_path, nn='lstm')
21-
res, labels = dnncrf.predict(sentence, return_labels=True)
21+
dnncrf = DnnCrf(config=config, mode='predict', model_path=model_path, nn='bilstm')
22+
res, labels = dnncrf.predict_ll(sentence, return_labels=True)
2223
print(res)
23-
evaluate_cws(dnncrf, '../dnlp/data/cws/pku_test.pickle')
24+
evaluate_cws(dnncrf, '../dnlp/data/cws/msr_test.pickle')
2425

2526

2627
if __name__ == '__main__':

0 commit comments

Comments
 (0)