Skip to content

Commit 73001ee

Browse files
fix dnn-crf prediction bug
1 parent 2e755ad commit 73001ee

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

python/dnlp/core/dnn_crf_base.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import numpy as np
33
import pickle
44
from dnlp.config.config import DnnCrfConfig
5-
from dnlp.utils.constant import BATCH_PAD, STRT_VAL, END_VAL, TAG_PAD, TAG_BEGIN, TAG_INSIDE, TAG_SINGLE
5+
from dnlp.utils.constant import BATCH_PAD, UNK, STRT_VAL, END_VAL, TAG_PAD, TAG_BEGIN, TAG_INSIDE, TAG_SINGLE
66

77

88
class DnnCrfBase(object):
9-
def __init__(self, config: DnnCrfConfig=None, data_path: str = '', mode: str = 'train', model_path: str = ''):
9+
def __init__(self, config: DnnCrfConfig = None, data_path: str = '', mode: str = 'train', model_path: str = ''):
1010
# 加载数据 (load data)
1111
self.data_path = data_path
1212
self.config_suffix = '.config.pickle'
@@ -18,7 +18,7 @@ def __init__(self, config: DnnCrfConfig=None, data_path: str = '', mode: str = '
1818
self.dictionary, self.tags = self.__load_config()
1919
self.tags_count = len(self.tags) - 1 # 忽略TAG_PAD
2020
self.tags_map = self.__generate_tag_map()
21-
self.reversed_tags_map = dict(zip(self.tags_map.values(),self.tags_map.keys()))
21+
self.reversed_tags_map = dict(zip(self.tags_map.values(), self.tags_map.keys()))
2222
self.dict_size = len(self.dictionary)
2323
# 初始化超参数 (initialize hyperparameters)
2424
self.skip_left = config.skip_left
@@ -82,7 +82,7 @@ def get_batch(self) -> (np.ndarray, np.ndarray, np.ndarray):
8282
else:
8383
ext_size = self.batch_length - len(chs)
8484
chs_batch[i] = chs + ext_size * [self.dictionary[BATCH_PAD]]
85-
lls_batch[i] = list(map(lambda t: self.tags_map[t], lls)) + ext_size * [0]#[self.tags_map[TAG_PAD]]
85+
lls_batch[i] = list(map(lambda t: self.tags_map[t], lls)) + ext_size * [0] # [self.tags_map[TAG_PAD]]
8686

8787
self.batch_start = new_start
8888
return self.indices2input(chs_batch), np.array(lls_batch, dtype=np.int32), np.array(len_batch, dtype=np.int32)
@@ -111,7 +111,8 @@ def viterbi(self, emission: np.ndarray, transition: np.ndarray, transition_init:
111111
return corr_path
112112

113113
def sentence2indices(self, sentence: str) -> list:
114-
return list(map(lambda ch: self.dictionary[ch], sentence))
114+
expr = lambda ch: self.dictionary[ch] if ch in self.dictionary else self.dictionary[UNK]
115+
return list(map(expr, sentence))
115116

116117
def indices2input(self, indices: list) -> np.ndarray:
117118
res = []
@@ -173,10 +174,10 @@ def tags2entities(self, sentence: str, tags_seq: np.ndarray, return_start: bool
173174
else:
174175
return entities
175176

176-
def tag2sequences(self, tags_seq:np.ndarray):
177+
def tag2sequences(self, tags_seq: np.ndarray):
177178
seq = []
178179

179180
for tag in tags_seq:
180181
seq.append(self.reversed_tags_map[tag])
181182

182-
return seq
183+
return seq

0 commit comments

Comments
 (0)