2
2
import numpy as np
3
3
import pickle
4
4
from dnlp .config .config import DnnCrfConfig
5
- from dnlp .utils .constant import BATCH_PAD , STRT_VAL , END_VAL , TAG_PAD , TAG_BEGIN , TAG_INSIDE , TAG_SINGLE
5
+ from dnlp .utils .constant import BATCH_PAD , UNK , STRT_VAL , END_VAL , TAG_PAD , TAG_BEGIN , TAG_INSIDE , TAG_SINGLE
6
6
7
7
8
8
class DnnCrfBase (object ):
9
- def __init__ (self , config : DnnCrfConfig = None , data_path : str = '' , mode : str = 'train' , model_path : str = '' ):
9
+ def __init__ (self , config : DnnCrfConfig = None , data_path : str = '' , mode : str = 'train' , model_path : str = '' ):
10
10
# 加载数据
11
11
self .data_path = data_path
12
12
self .config_suffix = '.config.pickle'
@@ -18,7 +18,7 @@ def __init__(self, config: DnnCrfConfig=None, data_path: str = '', mode: str = '
18
18
self .dictionary , self .tags = self .__load_config ()
19
19
self .tags_count = len (self .tags ) - 1 # 忽略TAG_PAD
20
20
self .tags_map = self .__generate_tag_map ()
21
- self .reversed_tags_map = dict (zip (self .tags_map .values (),self .tags_map .keys ()))
21
+ self .reversed_tags_map = dict (zip (self .tags_map .values (), self .tags_map .keys ()))
22
22
self .dict_size = len (self .dictionary )
23
23
# 初始化超参数
24
24
self .skip_left = config .skip_left
@@ -82,7 +82,7 @@ def get_batch(self) -> (np.ndarray, np.ndarray, np.ndarray):
82
82
else :
83
83
ext_size = self .batch_length - len (chs )
84
84
chs_batch [i ] = chs + ext_size * [self .dictionary [BATCH_PAD ]]
85
- lls_batch [i ] = list (map (lambda t : self .tags_map [t ], lls )) + ext_size * [0 ]# [self.tags_map[TAG_PAD]]
85
+ lls_batch [i ] = list (map (lambda t : self .tags_map [t ], lls )) + ext_size * [0 ] # [self.tags_map[TAG_PAD]]
86
86
87
87
self .batch_start = new_start
88
88
return self .indices2input (chs_batch ), np .array (lls_batch , dtype = np .int32 ), np .array (len_batch , dtype = np .int32 )
@@ -111,7 +111,8 @@ def viterbi(self, emission: np.ndarray, transition: np.ndarray, transition_init:
111
111
return corr_path
112
112
113
113
def sentence2indices(self, sentence: str) -> list:
    """Convert a sentence into the list of dictionary indices of its characters.

    Each character is looked up in ``self.dictionary``; a character that is
    not a key of the dictionary is replaced by the entry stored under ``UNK``.

    :param sentence: text to convert, processed one character at a time
    :return: list with one dictionary value per character of ``sentence``
    """
    # UNK is only evaluated for characters missing from the dictionary.
    return [
        self.dictionary[ch if ch in self.dictionary else UNK]
        for ch in sentence
    ]
115
116
116
117
def indices2input (self , indices : list ) -> np .ndarray :
117
118
res = []
@@ -173,10 +174,10 @@ def tags2entities(self, sentence: str, tags_seq: np.ndarray, return_start: bool
173
174
else :
174
175
return entities
175
176
176
def tag2sequences(self, tags_seq: np.ndarray):
    """Translate a sequence of tag indices into their tag labels.

    Every element of ``tags_seq`` is used as a key into
    ``self.reversed_tags_map`` and the looked-up values are collected in
    the same order.

    :param tags_seq: iterable of tag indices
    :return: list of the corresponding entries of ``self.reversed_tags_map``
    """
    return [self.reversed_tags_map[tag] for tag in tags_seq]
0 commit comments