8
8
9
9
class DnnCrf (DnnCrfBase ):
10
10
def __init__ (self , * , config : DnnCrfConfig = None , data_path : str = '' , dtype : type = tf .float32 , mode : str = 'train' ,
11
- train : str = 'll' , nn : str , model_path : str = '' ):
11
+ predict : str = 'll' , nn : str , model_path : str = '' ):
12
12
if mode not in ['train' , 'predict' ]:
13
13
raise Exception ('mode error' )
14
14
if nn not in ['mlp' , 'rnn' , 'lstm' , 'bilstm' , 'gru' ]:
@@ -27,113 +27,42 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
27
27
if mode == 'train' :
28
28
self .input = tf .placeholder (tf .int32 , [self .batch_size , self .batch_length , self .windows_size ])
29
29
self .real_indices = tf .placeholder (tf .int32 , [self .batch_size , self .batch_length ])
30
- self .seq_length = tf .placeholder (tf .int32 , [self .batch_size ])
31
30
else :
32
31
self .input = tf .placeholder (tf .int32 , [None , self .windows_size ])
33
32
33
+ self .seq_length = tf .placeholder (tf .int32 , [None ])
34
+
34
35
# 查找表层
35
36
self .embedding_layer = self .get_embedding_layer ()
36
37
# 隐藏层
37
38
if nn == 'mlp' :
38
39
self .hidden_layer = self .get_mlp_layer (tf .transpose (self .embedding_layer ))
39
40
elif nn == 'lstm' :
40
41
self .hidden_layer = self .get_lstm_layer (self .embedding_layer )
42
+ elif nn == 'bilstm' :
43
+ self .hidden_layer = self .get_bilstm_layer (self .embedding_layer )
41
44
elif nn == 'gru' :
42
- self .hidden_layer = self .get_gru_layer (tf . transpose ( self .embedding_layer ) )
45
+ self .hidden_layer = self .get_gru_layer (self .embedding_layer )
43
46
else :
44
- self .hidden_layer = self .get_rnn_layer (tf . transpose ( self .embedding_layer ) )
47
+ self .hidden_layer = self .get_rnn_layer (self .embedding_layer )
45
48
# 输出层
46
49
self .output = self .get_output_layer (self .hidden_layer )
47
50
48
51
if mode == 'predict' :
49
- self .output = tf .squeeze (tf .transpose (self .output ), axis = 2 )
52
+ if predict != 'll' :
53
+ self .output = tf .squeeze (tf .transpose (self .output ), axis = 2 )
54
+ self .seq , self .best_score = tf .contrib .crf .crf_decode (self .output , self .transition , self .seq_length )
50
55
self .sess = tf .Session ()
51
56
self .sess .run (tf .global_variables_initializer ())
52
57
tf .train .Saver ().restore (save_path = self .model_path , sess = self .sess )
53
- elif train == 'll' :
54
- self .ll_loss , _ = tf .contrib .crf .crf_log_likelihood (self .output , self .real_indices , self .seq_length ,
55
- self .transition )
56
- self .optimizer = tf .train .AdagradOptimizer (self .learning_rate )
57
- self .train_ll = self .optimizer .minimize (- self .ll_loss )
58
58
else :
59
- # 构建训练函数
60
- # 训练用placeholder
61
- self .ll_corr = tf .placeholder (tf .int32 , shape = [None , 3 ])
62
- self .ll_curr = tf .placeholder (tf .int32 , shape = [None , 3 ])
63
- self .trans_corr = tf .placeholder (tf .int32 , [None , 2 ])
64
- self .trans_curr = tf .placeholder (tf .int32 , [None , 2 ])
65
- self .trans_init_corr = tf .placeholder (tf .int32 , [None , 1 ])
66
- self .trans_init_curr = tf .placeholder (tf .int32 , [None , 1 ])
67
- # 损失函数
68
- self .loss , self .loss_with_init = self .get_loss ()
59
+ self .loss , _ = tf .contrib .crf .crf_log_likelihood (self .output , self .real_indices , self .seq_length ,
60
+ self .transition )
69
61
self .optimizer = tf .train .AdagradOptimizer (self .learning_rate )
70
- self .train = self . optimizer . minimize ( self . loss )
71
- self .train_with_init = self .optimizer .minimize (self .loss_with_init )
62
+ self .new_optimizer = tf . train . AdamOptimizer ( )
63
+ self .train = self .optimizer .minimize (- self .loss )
72
64
73
65
def fit (self , epochs : int = 100 , interval : int = 20 ):
74
- with tf .Session () as sess :
75
- tf .global_variables_initializer ().run ()
76
- saver = tf .train .Saver (max_to_keep = 100 )
77
- for epoch in range (1 , epochs + 1 ):
78
- print ('epoch:' , epoch )
79
- for _ in range (self .batch_count ):
80
- characters , labels , lengths = self .get_batch ()
81
- self .fit_batch (characters , labels , lengths , sess )
82
- # if epoch % interval == 0:
83
- model_path = '../dnlp/models/cws{0}.ckpt' .format (epoch )
84
- saver .save (sess , model_path )
85
- self .save_config (model_path )
86
-
87
- def fit_batch (self , characters , labels , lengths , sess ):
88
- scores = sess .run (self .output , feed_dict = {self .input : characters })
89
- transition = self .transition .eval (session = sess )
90
- transition_init = self .transition_init .eval (session = sess )
91
- update_labels_pos = None
92
- update_labels_neg = None
93
- current_labels = []
94
- trans_pos_indices = []
95
- trans_neg_indices = []
96
- trans_init_pos_indices = []
97
- trans_init_neg_indices = []
98
- for i in range (self .batch_size ):
99
- current_label = self .viterbi (scores [:, :lengths [i ], i ], transition , transition_init )
100
- current_labels .append (current_label )
101
- diff_tag = np .subtract (labels [i , :lengths [i ]], current_label )
102
- update_index = np .where (diff_tag != 0 )[0 ]
103
- update_length = len (update_index )
104
- if update_length == 0 :
105
- continue
106
- update_label_pos = np .stack ([labels [i , update_index ], update_index , i * np .ones ([update_length ])], axis = - 1 )
107
- update_label_neg = np .stack ([current_label [update_index ], update_index , i * np .ones ([update_length ])], axis = - 1 )
108
- if update_labels_pos is not None :
109
- np .concatenate ((update_labels_pos , update_label_pos ))
110
- np .concatenate ((update_labels_neg , update_label_neg ))
111
- else :
112
- update_labels_pos = update_label_pos
113
- update_labels_neg = update_label_neg
114
-
115
- trans_pos_index , trans_neg_index , trans_init_pos , trans_init_neg , update_init = self .generate_transition_update_index (
116
- labels [i , :lengths [i ]], current_labels [i ])
117
-
118
- trans_pos_indices .extend (trans_pos_index )
119
- trans_neg_indices .extend (trans_neg_index )
120
-
121
- if update_init :
122
- trans_init_pos_indices .append (trans_init_pos )
123
- trans_init_neg_indices .append (trans_init_neg )
124
-
125
- if update_labels_pos is not None and update_labels_neg is not None :
126
- feed_dict = {self .input : characters , self .ll_curr : update_labels_neg , self .ll_corr : update_labels_pos ,
127
- self .trans_curr : trans_neg_indices , self .trans_corr : trans_pos_indices }
128
-
129
- if not trans_init_pos_indices :
130
- sess .run (self .train , feed_dict )
131
- else :
132
- feed_dict [self .trans_init_corr ] = trans_init_pos_indices
133
- feed_dict [self .trans_init_curr ] = trans_init_neg_indices
134
- sess .run (self .train_with_init , feed_dict )
135
-
136
- def fit_ll (self , epochs : int = 100 , interval : int = 20 ):
137
66
with tf .Session () as sess :
138
67
tf .global_variables_initializer ().run ()
139
68
saver = tf .train .Saver (max_to_keep = epochs )
@@ -143,44 +72,13 @@ def fit_ll(self, epochs: int = 100, interval: int = 20):
143
72
characters , labels , lengths = self .get_batch ()
144
73
# scores = sess.run(self.output, feed_dict={self.input: characters})
145
74
feed_dict = {self .input : characters , self .real_indices : labels , self .seq_length : lengths }
146
- sess .run (self .train_ll , feed_dict = feed_dict )
75
+ sess .run (self .train , feed_dict = feed_dict )
147
76
# self.fit_batch(characters, labels, lengths, sess)
148
77
# if epoch % interval == 0:
149
78
model_path = '../dnlp/models/cws{0}.ckpt' .format (epoch )
150
79
saver .save (sess , model_path )
151
80
self .save_config (model_path )
152
81
153
- def fit_batch_ll (self ):
154
- pass
155
-
156
- def generate_transition_update_index (self , correct_labels , current_labels ):
157
- if correct_labels .shape != current_labels .shape :
158
- print ('sequence length is not equal' )
159
- return None
160
-
161
- before_corr = correct_labels [0 ]
162
- before_curr = current_labels [0 ]
163
- update_init = False
164
-
165
- trans_init_pos = None
166
- trans_init_neg = None
167
- trans_pos = []
168
- trans_neg = []
169
-
170
- if before_corr != before_curr :
171
- trans_init_pos = [before_corr ]
172
- trans_init_neg = [before_curr ]
173
- update_init = True
174
-
175
- for _ , (corr_label , curr_label ) in enumerate (zip (correct_labels [1 :], current_labels [1 :])):
176
- if corr_label != curr_label or before_corr != before_curr :
177
- trans_pos .append ([before_corr , corr_label ])
178
- trans_neg .append ([before_curr , curr_label ])
179
- before_corr = corr_label
180
- before_curr = curr_label
181
-
182
- return trans_pos , trans_neg , trans_init_pos , trans_init_neg , update_init
183
-
184
82
def predict (self , sentence : str , return_labels = False ):
185
83
if self .mode != 'predict' :
186
84
raise Exception ('mode is not allowed to predict' )
@@ -194,6 +92,22 @@ def predict(self, sentence: str, return_labels=False):
194
92
else :
195
93
return self .tags2words (sentence , labels ), self .tag2sequences (labels )
196
94
95
def predict_ll(self, sentence: str, return_labels=False):
    """Label *sentence* with the CRF decoding op and convert the tags to words.

    Args:
        sentence: Text to label; its length is fed as the sequence length.
        return_labels: When true, also return ``self.tag2sequences(labels)``.

    Returns:
        ``self.tags2words(sentence, labels)``, or a tuple of that value and
        ``self.tag2sequences(labels)`` when *return_labels* is true.

    Raises:
        Exception: If the model was not created in ``'predict'`` mode.
    """
    if self.mode != 'predict':
        raise Exception('mode is not allowed to predict')

    model_input = self.indices2input(self.sentence2indices(sentence))
    feed_dict = {self.input: model_input, self.seq_length: [len(sentence)]}
    # Only the decoded tags are used below, so nothing else is fetched.
    labels = self.sess.run(self.seq, feed_dict=feed_dict)
    # Flatten instead of squeeze: squeezing the labels of a one-character
    # sentence would yield a 0-d array that cannot be iterated.
    labels = np.reshape(labels, [-1])

    words = self.tags2words(sentence, labels)
    if return_labels:
        return words, self.tag2sequences(labels)
    return words
110
+
197
111
def get_embedding_layer (self ) -> tf .Tensor :
198
112
embeddings = self .__get_variable ([self .dict_size , self .embed_size ], 'embeddings' )
199
113
self .params .append (embeddings )
@@ -215,19 +129,27 @@ def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
215
129
rnn = tf .nn .rnn_cell .RNNCell (self .hidden_units )
216
130
rnn_output , rnn_out_state = tf .nn .dynamic_rnn (rnn , layer , dtype = self .dtype )
217
131
self .params += [v for v in tf .global_variables () if v .name .startswith ('rnn' )]
218
- return tf . transpose ( rnn_output )
132
+ return rnn_output
219
133
220
134
def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Run an LSTM cell with ``self.hidden_units`` units over *layer*.

    Every global variable whose name starts with ``'rnn'`` is appended to
    ``self.params``. The outputs tensor returned by ``tf.nn.dynamic_rnn`` is
    returned; the final state is discarded.
    """
    cell = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
    outputs, _ = tf.nn.dynamic_rnn(cell, layer, dtype=self.dtype)
    rnn_variables = [var for var in tf.global_variables() if var.name.startswith('rnn')]
    self.params.extend(rnn_variables)
    return outputs
225
139
140
def get_bilstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Run a bidirectional LSTM over *layer* and concatenate both directions.

    Args:
        layer: Input tensor passed to ``tf.nn.bidirectional_dynamic_rnn``.

    Returns:
        The forward and backward outputs concatenated along the last axis,
        so the last dimension has ``self.hidden_units`` entries.
    """
    # Split so the concatenated width is exactly hidden_units even when
    # hidden_units is odd (the forward cell takes the extra unit).
    backward_units = self.hidden_units // 2
    forward_units = self.hidden_units - backward_units
    lstm_fw = tf.nn.rnn_cell.LSTMCell(forward_units)
    lstm_bw = tf.nn.rnn_cell.LSTMCell(backward_units)
    (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(
        lstm_fw, lstm_bw, layer, sequence_length=self.seq_length, dtype=self.dtype)
    # bidirectional_dynamic_rnn creates its variables under the default scope
    # "bidirectional_rnn", so a bare 'rnn' prefix would never match them.
    self.params += [v for v in tf.global_variables()
                    if v.name.startswith(('rnn', 'bidirectional_rnn'))]
    return tf.concat([output_fw, output_bw], -1)
147
+
226
148
def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Run a GRU cell with ``self.hidden_units`` units over *layer*.

    Every global variable whose name starts with ``'rnn'`` is appended to
    ``self.params``. The outputs tensor returned by ``tf.nn.dynamic_rnn`` is
    returned as-is; the final state is discarded.
    """
    cell = tf.nn.rnn_cell.GRUCell(self.hidden_units)
    outputs, _ = tf.nn.dynamic_rnn(cell, layer, dtype=self.dtype)
    rnn_variables = [var for var in tf.global_variables() if var.name.startswith('rnn')]
    self.params.extend(rnn_variables)
    return outputs
231
153
232
154
def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Apply dropout with rate ``self.dropout_rate`` to *layer*.

    ``tf.layers.dropout`` defaults to ``training=False``, which returns the
    input unchanged, so the flag is passed explicitly: dropout is active only
    when the model is in ``'train'`` mode.
    """
    # Fall back to the inactive behaviour if ``mode`` has not been set yet.
    is_training = getattr(self, 'mode', None) == 'train'
    return tf.layers.dropout(layer, self.dropout_rate, training=is_training)
@@ -238,17 +160,5 @@ def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
238
160
self .params += [output_weight , output_bias ]
239
161
return tf .tensordot (layer , output_weight , [[2 ], [0 ]]) + output_bias
240
162
241
- def get_loss (self ) -> (tf .Tensor , tf .Tensor ):
242
- output_loss = tf .reduce_sum (tf .gather_nd (self .output , self .ll_curr ) - tf .gather_nd (self .output , self .ll_corr ))
243
- trans_loss = tf .gather_nd (self .transition , self .trans_curr ) - tf .gather_nd (self .transition , self .trans_corr )
244
- trans_i_curr = tf .gather_nd (self .transition_init , self .trans_init_curr )
245
- trans_i_corr = tf .gather_nd (self .transition_init , self .trans_init_corr )
246
- trans_init_loss = tf .reduce_sum (trans_i_curr - trans_i_corr )
247
- loss = output_loss + trans_loss
248
- regu = tf .contrib .layers .apply_regularization (tf .contrib .layers .l2_regularizer (self .lam ), self .params )
249
- l1 = loss + regu
250
- l2 = l1 + trans_init_loss
251
- return l1 , l2
252
-
253
163
def __get_variable(self, size, name) -> tf.Variable:
    """Create a variable of shape *size* initialised from a truncated normal.

    The standard deviation of the initialiser is ``1 / sqrt(size[-1])`` and
    the values use ``self.dtype``.
    """
    stddev = 1.0 / math.sqrt(size[-1])
    initial_value = tf.truncated_normal(size, stddev=stddev, dtype=self.dtype)
    return tf.Variable(initial_value, name=name)
0 commit comments