@@ -1,6 +1,8 @@
 import argparse
 import torch
 import torchtext.data as data
+from torchtext.vocab import Vectors
+
 import model
 import train
 import dataset
@@ -14,7 +16,6 @@
                     help='how many steps to wait before logging training status [default: 1]')
 parser.add_argument('-test-interval', type=int, default=100,
                     help='how many steps to wait before testing [default: 100]')
-parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default: 500]')
 parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot')
 parser.add_argument('-early-stopping', type=int, default=1000,
                     help='number of iterations to wait before stopping when performance stops increasing')
@@ -26,6 +27,12 @@
 parser.add_argument('-filter-num', type=int, default=100, help='number of each size of filter')
 parser.add_argument('-filter-sizes', type=str, default='3,4,5',
                     help='comma-separated filter sizes to use for convolution')
+
+parser.add_argument('-static', type=bool, default=False, help='whether to use static pre-trained word vectors')
+parser.add_argument('-pretrained-name', type=str, default='sgns.zhihu.word',
+                    help='filename of pre-trained word vectors')
+parser.add_argument('-pretrained-path', type=str, default='pretrained', help='path of pre-trained word vectors')
+
 # device
 parser.add_argument('-device', type=int, default=-1, help='device to use for iterating data, -1 means cpu [default: -1]')

@@ -34,9 +41,18 @@
 args = parser.parse_args()


-def load_dataset(text_field, label_field, **kwargs):
+def load_word_vectors(model_name, model_path):
+    vectors = Vectors(name=model_name, cache=model_path)
+    return vectors
+
+
+def load_dataset(text_field, label_field, args, **kwargs):
     train_dataset, dev_dataset = dataset.get_dataset('data', text_field, label_field)
-    text_field.build_vocab(train_dataset, dev_dataset)
+    if args.static and args.pretrained_name and args.pretrained_path:
+        vectors = load_word_vectors(args.pretrained_name, args.pretrained_path)
+        text_field.build_vocab(train_dataset, dev_dataset, vectors=vectors)
+    else:
+        text_field.build_vocab(train_dataset, dev_dataset)
     label_field.build_vocab(train_dataset, dev_dataset)
     train_iter, dev_iter = data.Iterator.splits(
         (train_dataset, dev_dataset),
@@ -46,19 +62,24 @@ def load_dataset(text_field, label_field, **kwargs):
     return train_iter, dev_iter


-print("Loading data...")
+print('Loading data...')
 text_field = data.Field(lower=True)
 label_field = data.Field(sequential=False)
-train_iter, dev_iter = load_dataset(text_field, label_field, device=-1, repeat=False, shuffle=True)
+train_iter, dev_iter = load_dataset(text_field, label_field, args, device=-1, repeat=False, shuffle=True)

 args.vocabulary_size = len(text_field.vocab)
+if args.static:
+    args.embedding_dim = text_field.vocab.vectors.size()[-1]
+    args.vectors = text_field.vocab.vectors
 args.class_num = len(label_field.vocab)
 args.cuda = args.device != -1 and torch.cuda.is_available()
 args.filter_sizes = [int(size) for size in args.filter_sizes.split(',')]

-print("Parameters:")
+print('Parameters:')
 for attr, value in sorted(args.__dict__.items()):
-    print("\t{}={}".format(attr.upper(), value))
+    if attr in {'vectors'}:
+        continue
+    print('\t{}={}'.format(attr.upper(), value))

 text_cnn = model.TextCNN(args)
 if args.snapshot:
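Note on what this change wires together: `load_word_vectors` wraps torchtext's `Vectors` loader, where `Vectors(name=model_name, cache=model_path)` reads the word-vector file at `model_path/model_name` (here `pretrained/sgns.zhihu.word`) and caches a serialized copy for faster reloads. Passing the result to `build_vocab(..., vectors=vectors)` aligns an embedding matrix with the vocabulary, which is why the script can later take `args.embedding_dim` from `text_field.vocab.vectors.size()[-1]` and stash the matrix in `args.vectors`. The `model.py` side is not part of this diff, so the following is only a minimal sketch of how `TextCNN` might consume those two attributes, assuming the usual "static" TextCNN setup where the pre-trained matrix is copied into a frozen embedding layer:

import torch.nn as nn


class TextCNN(nn.Module):
    # Sketch of the embedding setup only; convolution/pooling layers omitted.
    def __init__(self, args):
        super().__init__()
        if args.static:
            # Assumption: copy the pre-trained matrix built by build_vocab
            # into the embedding layer and freeze it (the "static" channel).
            self.embedding = nn.Embedding.from_pretrained(args.vectors, freeze=True)
        else:
            # Randomly initialized embeddings, trained end-to-end.
            self.embedding = nn.Embedding(args.vocabulary_size, args.embedding_dim)

One caveat in the new flag itself: `-static` is declared with `type=bool`, and argparse applies `bool()` to the raw command-line string, so any non-empty value (including `-static False`) parses as `True`; only omitting the flag keeps the default `False`.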