Skip to content

Commit 63fdb13

Browse files
Remove the former train method; use TensorFlow's CRF log-likelihood instead, and add a BiLSTM layer
1 parent 73001ee commit 63fdb13

File tree

1 file changed

+41
-131
lines changed

1 file changed

+41
-131
lines changed

python/dnlp/core/dnn_crf.py

Lines changed: 41 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class DnnCrf(DnnCrfBase):
1010
def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32, mode: str = 'train',
11-
train: str = 'll', nn: str, model_path: str = ''):
11+
predict: str = 'll', nn: str, model_path: str = ''):
1212
if mode not in ['train', 'predict']:
1313
raise Exception('mode error')
1414
if nn not in ['mlp', 'rnn', 'lstm', 'bilstm', 'gru']:
@@ -27,113 +27,42 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
2727
if mode == 'train':
2828
self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
2929
self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
30-
self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
3130
else:
3231
self.input = tf.placeholder(tf.int32, [None, self.windows_size])
3332

33+
self.seq_length = tf.placeholder(tf.int32, [None])
34+
3435
# 查找表层
3536
self.embedding_layer = self.get_embedding_layer()
3637
# 隐藏层
3738
if nn == 'mlp':
3839
self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
3940
elif nn == 'lstm':
4041
self.hidden_layer = self.get_lstm_layer(self.embedding_layer)
42+
elif nn == 'bilstm':
43+
self.hidden_layer = self.get_bilstm_layer(self.embedding_layer)
4144
elif nn == 'gru':
42-
self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
45+
self.hidden_layer = self.get_gru_layer(self.embedding_layer)
4346
else:
44-
self.hidden_layer = self.get_rnn_layer(tf.transpose(self.embedding_layer))
47+
self.hidden_layer = self.get_rnn_layer(self.embedding_layer)
4548
# 输出层
4649
self.output = self.get_output_layer(self.hidden_layer)
4750

4851
if mode == 'predict':
49-
self.output = tf.squeeze(tf.transpose(self.output), axis=2)
52+
if predict != 'll':
53+
self.output = tf.squeeze(tf.transpose(self.output), axis=2)
54+
self.seq, self.best_score = tf.contrib.crf.crf_decode(self.output, self.transition, self.seq_length)
5055
self.sess = tf.Session()
5156
self.sess.run(tf.global_variables_initializer())
5257
tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
53-
elif train == 'll':
54-
self.ll_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
55-
self.transition)
56-
self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
57-
self.train_ll = self.optimizer.minimize(-self.ll_loss)
5858
else:
59-
# 构建训练函数
60-
# 训练用placeholder
61-
self.ll_corr = tf.placeholder(tf.int32, shape=[None, 3])
62-
self.ll_curr = tf.placeholder(tf.int32, shape=[None, 3])
63-
self.trans_corr = tf.placeholder(tf.int32, [None, 2])
64-
self.trans_curr = tf.placeholder(tf.int32, [None, 2])
65-
self.trans_init_corr = tf.placeholder(tf.int32, [None, 1])
66-
self.trans_init_curr = tf.placeholder(tf.int32, [None, 1])
67-
# 损失函数
68-
self.loss, self.loss_with_init = self.get_loss()
59+
self.loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
60+
self.transition)
6961
self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
70-
self.train = self.optimizer.minimize(self.loss)
71-
self.train_with_init = self.optimizer.minimize(self.loss_with_init)
62+
self.new_optimizer = tf.train.AdamOptimizer()
63+
self.train = self.optimizer.minimize(-self.loss)
7264

7365
def fit(self, epochs: int = 100, interval: int = 20):
74-
with tf.Session() as sess:
75-
tf.global_variables_initializer().run()
76-
saver = tf.train.Saver(max_to_keep=100)
77-
for epoch in range(1, epochs + 1):
78-
print('epoch:', epoch)
79-
for _ in range(self.batch_count):
80-
characters, labels, lengths = self.get_batch()
81-
self.fit_batch(characters, labels, lengths, sess)
82-
# if epoch % interval == 0:
83-
model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
84-
saver.save(sess, model_path)
85-
self.save_config(model_path)
86-
87-
def fit_batch(self, characters, labels, lengths, sess):
88-
scores = sess.run(self.output, feed_dict={self.input: characters})
89-
transition = self.transition.eval(session=sess)
90-
transition_init = self.transition_init.eval(session=sess)
91-
update_labels_pos = None
92-
update_labels_neg = None
93-
current_labels = []
94-
trans_pos_indices = []
95-
trans_neg_indices = []
96-
trans_init_pos_indices = []
97-
trans_init_neg_indices = []
98-
for i in range(self.batch_size):
99-
current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init)
100-
current_labels.append(current_label)
101-
diff_tag = np.subtract(labels[i, :lengths[i]], current_label)
102-
update_index = np.where(diff_tag != 0)[0]
103-
update_length = len(update_index)
104-
if update_length == 0:
105-
continue
106-
update_label_pos = np.stack([labels[i, update_index], update_index, i * np.ones([update_length])], axis=-1)
107-
update_label_neg = np.stack([current_label[update_index], update_index, i * np.ones([update_length])], axis=-1)
108-
if update_labels_pos is not None:
109-
np.concatenate((update_labels_pos, update_label_pos))
110-
np.concatenate((update_labels_neg, update_label_neg))
111-
else:
112-
update_labels_pos = update_label_pos
113-
update_labels_neg = update_label_neg
114-
115-
trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = self.generate_transition_update_index(
116-
labels[i, :lengths[i]], current_labels[i])
117-
118-
trans_pos_indices.extend(trans_pos_index)
119-
trans_neg_indices.extend(trans_neg_index)
120-
121-
if update_init:
122-
trans_init_pos_indices.append(trans_init_pos)
123-
trans_init_neg_indices.append(trans_init_neg)
124-
125-
if update_labels_pos is not None and update_labels_neg is not None:
126-
feed_dict = {self.input: characters, self.ll_curr: update_labels_neg, self.ll_corr: update_labels_pos,
127-
self.trans_curr: trans_neg_indices, self.trans_corr: trans_pos_indices}
128-
129-
if not trans_init_pos_indices:
130-
sess.run(self.train, feed_dict)
131-
else:
132-
feed_dict[self.trans_init_corr] = trans_init_pos_indices
133-
feed_dict[self.trans_init_curr] = trans_init_neg_indices
134-
sess.run(self.train_with_init, feed_dict)
135-
136-
def fit_ll(self, epochs: int = 100, interval: int = 20):
13766
with tf.Session() as sess:
13867
tf.global_variables_initializer().run()
13968
saver = tf.train.Saver(max_to_keep=epochs)
@@ -143,44 +72,13 @@ def fit_ll(self, epochs: int = 100, interval: int = 20):
14372
characters, labels, lengths = self.get_batch()
14473
# scores = sess.run(self.output, feed_dict={self.input: characters})
14574
feed_dict = {self.input: characters, self.real_indices: labels, self.seq_length: lengths}
146-
sess.run(self.train_ll, feed_dict=feed_dict)
75+
sess.run(self.train, feed_dict=feed_dict)
14776
# self.fit_batch(characters, labels, lengths, sess)
14877
# if epoch % interval == 0:
14978
model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
15079
saver.save(sess, model_path)
15180
self.save_config(model_path)
15281

153-
def fit_batch_ll(self):
154-
pass
155-
156-
def generate_transition_update_index(self, correct_labels, current_labels):
157-
if correct_labels.shape != current_labels.shape:
158-
print('sequence length is not equal')
159-
return None
160-
161-
before_corr = correct_labels[0]
162-
before_curr = current_labels[0]
163-
update_init = False
164-
165-
trans_init_pos = None
166-
trans_init_neg = None
167-
trans_pos = []
168-
trans_neg = []
169-
170-
if before_corr != before_curr:
171-
trans_init_pos = [before_corr]
172-
trans_init_neg = [before_curr]
173-
update_init = True
174-
175-
for _, (corr_label, curr_label) in enumerate(zip(correct_labels[1:], current_labels[1:])):
176-
if corr_label != curr_label or before_corr != before_curr:
177-
trans_pos.append([before_corr, corr_label])
178-
trans_neg.append([before_curr, curr_label])
179-
before_corr = corr_label
180-
before_curr = curr_label
181-
182-
return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init
183-
18482
def predict(self, sentence: str, return_labels=False):
18583
if self.mode != 'predict':
18684
raise Exception('mode is not allowed to predict')
@@ -194,6 +92,22 @@ def predict(self, sentence: str, return_labels=False):
19492
else:
19593
return self.tags2words(sentence, labels), self.tag2sequences(labels)
19694

95+
def predict_ll(self, sentence: str, return_labels=False):
96+
if self.mode != 'predict':
97+
raise Exception('mode is not allowed to predict')
98+
99+
input = self.indices2input(self.sentence2indices(sentence))
100+
runner = [self.seq, self.best_score, self.output, self.transition]
101+
labels, best_score, output, trans = self.sess.run(runner,
102+
feed_dict={self.input: input, self.seq_length: [len(sentence)]})
103+
# print(output)
104+
# print(trans)
105+
labels = np.squeeze(labels)
106+
if return_labels:
107+
return self.tags2words(sentence, labels), self.tag2sequences(labels)
108+
else:
109+
return self.tags2words(sentence, labels)
110+
197111
def get_embedding_layer(self) -> tf.Tensor:
198112
embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
199113
self.params.append(embeddings)
@@ -215,19 +129,27 @@ def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
215129
rnn = tf.nn.rnn_cell.RNNCell(self.hidden_units)
216130
rnn_output, rnn_out_state = tf.nn.dynamic_rnn(rnn, layer, dtype=self.dtype)
217131
self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
218-
return tf.transpose(rnn_output)
132+
return rnn_output
219133

220134
def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
221135
lstm = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
222136
lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
223137
self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
224138
return lstm_output
225139

140+
def get_bilstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
141+
lstm_fw = tf.nn.rnn_cell.LSTMCell(self.hidden_units//2)
142+
lstm_bw = tf.nn.rnn_cell.LSTMCell(self.hidden_units//2)
143+
bilstm_output, bilstm_output_state = tf.nn.bidirectional_dynamic_rnn(lstm_fw, lstm_bw, layer, self.seq_length,
144+
dtype=self.dtype)
145+
self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
146+
return tf.concat([bilstm_output[0],bilstm_output[1]],-1)
147+
226148
def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
227149
gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
228150
gru_output, gru_out_state = tf.nn.dynamic_rnn(gru, layer, dtype=self.dtype)
229151
self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
230-
return tf.transpose(gru_output)
152+
return gru_output
231153

232154
def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
233155
return tf.layers.dropout(layer, self.dropout_rate)
@@ -238,17 +160,5 @@ def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
238160
self.params += [output_weight, output_bias]
239161
return tf.tensordot(layer, output_weight, [[2], [0]]) + output_bias
240162

241-
def get_loss(self) -> (tf.Tensor, tf.Tensor):
242-
output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
243-
trans_loss = tf.gather_nd(self.transition, self.trans_curr) - tf.gather_nd(self.transition, self.trans_corr)
244-
trans_i_curr = tf.gather_nd(self.transition_init, self.trans_init_curr)
245-
trans_i_corr = tf.gather_nd(self.transition_init, self.trans_init_corr)
246-
trans_init_loss = tf.reduce_sum(trans_i_curr - trans_i_corr)
247-
loss = output_loss + trans_loss
248-
regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), self.params)
249-
l1 = loss + regu
250-
l2 = l1 + trans_init_loss
251-
return l1, l2
252-
253163
def __get_variable(self, size, name) -> tf.Variable:
254164
return tf.Variable(tf.truncated_normal(size, stddev=1.0 / math.sqrt(size[-1]), dtype=self.dtype), name=name)

0 commit comments

Comments
 (0)