Commit 471cb69

Add Non-Static Implementation
1 parent 94580a0 commit 471cb69

4 files changed: 13 additions and 3 deletions

README.md

Lines changed: 9 additions & 2 deletions
@@ -37,11 +37,18 @@ python3 main.py
 early stop by 1000 steps, acc: 94.0000%
 - [x] CNN-static: use pre-trained static word vectors
 ```bash
-python main.py -static
+python main.py -static=true
 ```
 >
 Batch[1900] - loss: 0.011894 acc: 100.0000%(128/128)
 Evaluation - loss: 0.000018 acc: 95.0000%(6679/7000)
 early stop by 1000 steps, acc: 95.0000%
-
+- [x] CNN-non-static: fine-tune pre-trained static word vectors
+```bash
+python main.py -static=true -non-static=true
+```
+>
+Batch[1500] - loss: 0.008823 acc: 99.0000%(127/128)
+Evaluation - loss: 0.000016 acc: 96.0000%(6729/7000)
+early stop by 1000 steps, acc: 96.0000%
 - [ ] CNN-multichannel

dataset.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,8 @@
 import re
 from torchtext import data
 import jieba
+import logging
+jieba.setLogLevel(logging.INFO)
 
 regex = re.compile(r'[^\u4e00-\u9fa5aA-Za-z0-9]')

main.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
                     help='comma-separated filter sizes to use for convolution')
 
 parser.add_argument('-static', type=bool, default=False, help='whether to use static pre-trained word vectors')
+parser.add_argument('-non-static', type=bool, default=False, help='whether to fine-tune static pre-trained word vectors')
 parser.add_argument('-pretrained-name', type=str, default='sgns.zhihu.word',
                     help='filename of pre-trained word vectors')
 parser.add_argument('-pretrained-path', type=str, default='pretrained', help='path of pre-trained word vectors')
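
A note on the `type=bool` flags in this hunk: `argparse` passes the raw string to `bool()`, and any non-empty string is truthy, so `-non-static=false` would still enable fine-tuning. The sketch below contrasts that behaviour with a stricter converter. `str2bool` and `build_parser` are hypothetical helpers written for illustration; they are not part of this commit.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false spellings to a bool."""
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


def build_parser(bool_type):
    """Build a parser with the two flags from main.py, using the given converter."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-static', type=bool_type, default=False)
    parser.add_argument('-non-static', type=bool_type, default=False)
    return parser


argv = ['-static=true', '-non-static=false']
naive = build_parser(bool).parse_args(argv)       # converter used in the commit
strict = build_parser(str2bool).parse_args(argv)  # stricter alternative

print(naive.non_static)   # True, because bool('false') is True
print(strict.non_static)  # False
```

The README commands keep working with either converter, since they only ever pass `true`. Using `action='store_true'` would be another option, but it would change the command-line form shown in the README.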

model.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ def __init__(self, args):
         embedding_dimension = args.embedding_dim
         self.embedding = nn.Embedding(vocabulary_size, embedding_dimension)
         if args.static:
-            self.embedding = self.embedding.from_pretrained(args.vectors)
+            self.embedding = self.embedding.from_pretrained(args.vectors, freeze=not args.non_static)
 
         self.convs = nn.ModuleList(
             [nn.Conv2d(chanel_num, filter_num, (size, embedding_dimension)) for size in filter_sizes])
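
The effect of the new `freeze=not args.non_static` argument can be checked in isolation. The following is a minimal sketch, assuming PyTorch is installed; the small random `vectors` tensor is a made-up stand-in for `args.vectors`, and the variable names are illustrative only.

```python
import torch
import torch.nn as nn

# Stand-in for args.vectors: 5 words with 3-dimensional embeddings (made-up values).
vectors = torch.randn(5, 3)

# freeze=True corresponds to -static alone: the weights are excluded from training.
static = nn.Embedding.from_pretrained(vectors, freeze=True)
# freeze=False corresponds to -static plus -non-static: the weights are fine-tuned.
non_static = nn.Embedding.from_pretrained(vectors, freeze=False)

print(static.weight.requires_grad)
print(non_static.weight.requires_grad)

# A backward pass only produces a gradient for the non-frozen embedding.
tokens = torch.tensor([0, 2])
non_static(tokens).sum().backward()
print(non_static.weight.grad is not None)
```

`from_pretrained` defaults to `freeze=True`, which is why the line before this commit always kept the pre-trained vectors fixed.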
