Commit 8ec9bdd

add static word vectors version
1 parent: 3940cf4

5 files changed: 40 additions & 13 deletions


.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -102,3 +102,6 @@ trees
 
 # pycharm
 .idea
+
+#word2vec
+sgns.zhihu.word
```

README.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -18,6 +18,6 @@ python3 main.py
 ```
 
 ## Structure
--[x] CNN-non-static
--[ ] CNN-static
+-[x] CNN-non-static (randomly initialized Embedding)
+-[x] CNN-static (uses pre-trained static word vectors)
 -[ ] CNN-multichannel
````

main.py

Lines changed: 28 additions & 7 deletions

```diff
@@ -1,6 +1,8 @@
 import argparse
 import torch
 import torchtext.data as data
+from torchtext.vocab import Vectors
+
 import model
 import train
 import dataset
@@ -14,7 +16,6 @@
                     help='how many steps to wait before logging training status [default: 1]')
 parser.add_argument('-test-interval', type=int, default=100,
                     help='how many steps to wait before testing [default: 100]')
-parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default:500]')
 parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot')
 parser.add_argument('-early-stopping', type=int, default=1000,
                     help='iteration numbers to stop without performance increasing')
@@ -26,6 +27,12 @@
 parser.add_argument('-filter-num', type=int, default=100, help='number of each size of filter')
 parser.add_argument('-filter-sizes', type=str, default='3,4,5',
                     help='comma-separated filter sizes to use for convolution')
+
+parser.add_argument('-static', type=bool, default=False, help='whether to use static pre-trained word vectors')
+parser.add_argument('-pretrained-name', type=str, default='sgns.zhihu.word',
+                    help='filename of pre-trained word vectors')
+parser.add_argument('-pretrained-path', type=str, default='pretrained', help='path of pre-trained word vectors')
+
 # device
 parser.add_argument('-device', type=int, default=-1, help='device to use for iterate data, -1 mean cpu [default: -1]')
 
@@ -34,9 +41,18 @@
 args = parser.parse_args()
 
 
-def load_dataset(text_field, label_field, **kwargs):
+def load_word_vectors(model_name, model_path):
+    vectors = Vectors(name=model_name, cache=model_path)
+    return vectors
+
+
+def load_dataset(text_field, label_field, args, **kwargs):
     train_dataset, dev_dataset = dataset.get_dataset('data', text_field, label_field)
-    text_field.build_vocab(train_dataset, dev_dataset)
+    if args.static and args.pretrained_name and args.pretrained_path:
+        vectors = load_word_vectors(args.pretrained_name, args.pretrained_path)
+        text_field.build_vocab(train_dataset, dev_dataset, vectors=vectors)
+    else:
+        text_field.build_vocab(train_dataset, dev_dataset)
     label_field.build_vocab(train_dataset, dev_dataset)
     train_iter, dev_iter = data.Iterator.splits(
         (train_dataset, dev_dataset),
@@ -46,19 +62,24 @@ def load_dataset(text_field, label_field, **kwargs):
     return train_iter, dev_iter
 
 
-print("Loading data...")
+print('Loading data...')
 text_field = data.Field(lower=True)
 label_field = data.Field(sequential=False)
-train_iter, dev_iter = load_dataset(text_field, label_field, device=-1, repeat=False, shuffle=True)
+train_iter, dev_iter = load_dataset(text_field, label_field, args, device=-1, repeat=False, shuffle=True)
 
 args.vocabulary_size = len(text_field.vocab)
+if args.static:
+    args.embedding_dim = text_field.vocab.vectors.size()[-1]
+    args.vectors = text_field.vocab.vectors
 args.class_num = len(label_field.vocab)
 args.cuda = args.device != -1 and torch.cuda.is_available()
 args.filter_sizes = [int(size) for size in args.filter_sizes.split(',')]
 
-print("Parameters:")
+print('Parameters:')
 for attr, value in sorted(args.__dict__.items()):
-    print("\t{}={}".format(attr.upper(), value))
+    if attr in {'vectors'}:
+        continue
+    print('\t{}={}'.format(attr.upper(), value))
 
 text_cnn = model.TextCNN(args)
 if args.snapshot:
```
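For readers unfamiliar with the torchtext 0.x API used here, the sketch below shows how `Vectors` and `build_vocab(..., vectors=...)` fit together. The toy dataset is hypothetical; only the `Vectors` call and the vocabulary wiring mirror the diff above, and a file `pretrained/sgns.zhihu.word` in plain word2vec text format is assumed to exist.

```python
import torchtext.data as data
from torchtext.vocab import Vectors

# Assumed to exist: pretrained/sgns.zhihu.word in word2vec text format.
# Vectors parses it once and caches a serialized .pt copy in the cache
# directory for faster reloads.
vectors = Vectors(name='sgns.zhihu.word', cache='pretrained')

text_field = data.Field(lower=True)          # default tokenizer: whitespace split
label_field = data.Field(sequential=False)
fields = [('text', text_field), ('label', label_field)]

# Hypothetical two-example dataset standing in for dataset.get_dataset().
examples = [data.Example.fromlist(pair, fields)
            for pair in [('这 部 电影 很 好看', '1'), ('剧情 很 差', '0')]]
toy_dataset = data.Dataset(examples, fields)

# vectors= copies each in-vocabulary token's embedding into
# text_field.vocab.vectors; tokens missing from the file get zero rows.
text_field.build_vocab(toy_dataset, vectors=vectors)

# The embedding size is then read back from the loaded matrix, exactly
# as main.py now does via text_field.vocab.vectors.size()[-1].
print(text_field.vocab.vectors.size())
```

One caveat about the new flag: `argparse` with `type=bool` converts any non-empty string to `True`, so `-static False` on the command line still enables static mode; `action='store_true'` is the usual idiom for boolean flags.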

model.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -8,14 +8,17 @@ def __init__(self, args):
         super(TextCNN, self).__init__()
         self.args = args
 
-        vocabulary_size = args.vocabulary_size
-        embedding_dimension = args.embedding_dim
         class_num = args.class_num
         chanel_num = 1
         filter_num = args.filter_num
         filter_sizes = args.filter_sizes
 
+        vocabulary_size = args.vocabulary_size
+        embedding_dimension = args.embedding_dim
         self.embedding = nn.Embedding(vocabulary_size, embedding_dimension)
+        if args.static:
+            self.embedding = self.embedding.from_pretrained(args.vectors)
+
         self.convs = nn.ModuleList(
             [nn.Conv2d(chanel_num, filter_num, (size, embedding_dimension)) for size in filter_sizes])
         self.dropout = nn.Dropout(args.dropout)
```
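A detail worth knowing about `nn.Embedding.from_pretrained` (PyTorch behaviour, not specific to this commit): it is a classmethod that builds a brand-new embedding around the given matrix, so the `nn.Embedding(...)` constructed just before it is simply discarded, and its `freeze` parameter defaults to `True`, which is what makes this variant "static". A minimal sketch with a hypothetical weight matrix:

```python
import torch
import torch.nn as nn

# Hypothetical pre-trained matrix: 4 tokens, 5 dimensions.
weights = torch.randn(4, 5)

embedding = nn.Embedding.from_pretrained(weights)  # freeze=True by default
print(embedding.weight.requires_grad)              # False: never updated by the optimizer

# freeze=False keeps the pre-trained initialization but lets training
# fine-tune it, i.e. the CNN-non-static variant listed in the README.
tunable = nn.Embedding.from_pretrained(weights, freeze=False)
print(tunable.weight.requires_grad)                # True
```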

train.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -43,8 +43,8 @@ def train(train_iter, dev_iter, model, args):
                     save(model, args.save_dir, 'best', steps)
             else:
                 if steps - last_step >= args.early_stopping:
-                    print('\nearly stop by {} steps.'.format(args.early_stopping))
-                    break
+                    print('\nearly stop by {} steps, acc: {:.4f}%'.format(args.early_stopping, best_acc))
+                    raise KeyboardInterrupt
 
 
 def eval(data_iter, model, args):
```
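Replacing `break` with `raise KeyboardInterrupt` only works cleanly if the training entry point catches it; a common pattern (an assumption here, not shown in this diff) is for main.py to wrap the call so that early stopping and a manual Ctrl-C exit through the same path:

```python
# Hypothetical tail of main.py; train_iter, dev_iter, text_cnn and args
# are the objects built earlier in that script.
try:
    train.train(train_iter, dev_iter, text_cnn, args)
except KeyboardInterrupt:
    # Triggered by the early-stopping raise above or by a real Ctrl-C.
    print('Exiting from training early')
```

Using the exception rather than `break` lets early stopping unwind out of the nested epoch/batch loops in one step instead of threading a flag through them.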
