Skip to content

Commit 7b7b828

Browse files
committed
Add multichannel implementation
1 parent 471cb69 commit 7b7b828

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,19 @@ python3 main.py
4343
Batch[1900] - loss: 0.011894 acc: 100.0000%(128/128)
4444
Evaluation - loss: 0.000018 acc: 95.0000%(6679/7000)
4545
early stop by 1000 steps, acc: 95.0000%
46-
- [x] CNN-static 微调预训练的静态词向量
46+
- [x] CNN-static 微调预训练的词向量
4747
```bash
48-
python main.py -static=true -no-static=true
48+
python main.py -static=true -non-static=true
4949
```
5050
>
5151
Batch[1500] - loss: 0.008823 acc: 99.0000%(127/128))
5252
Evaluation - loss: 0.000016 acc: 96.0000%(6729/7000)
5353
early stop by 1000 steps, acc: 96.0000%
54-
- [ ] CNN-multichannel
54+
- [x] CNN-multichannel 微调加静态
55+
```bash
56+
python main.py -static=true -non-static=true -multichannel=true
57+
```
58+
>
59+
Batch[1500] - loss: 0.023020 acc: 98.0000%(126/128))
60+
Evaluation - loss: 0.000016 acc: 96.0000%(6744/7000)
61+
early stop by 1000 steps, acc: 96.0000%

main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
parser.add_argument('-static', type=bool, default=False, help='whether to use static pre-trained word vectors')
3232
parser.add_argument('-non-static', type=bool, default=False, help='whether to fine-tune static pre-trained word vectors')
33+
parser.add_argument('-multichannel', type=bool, default=False, help='whether to use 2 channel of word vectors')
3334
parser.add_argument('-pretrained-name', type=str, default='sgns.zhihu.word',
3435
help='filename of pre-trained word vectors')
3536
parser.add_argument('-pretrained-path', type=str, default='pretrained', help='path of pre-trained word vectors')
@@ -72,6 +73,9 @@ def load_dataset(text_field, label_field, args, **kwargs):
7273
if args.static:
7374
args.embedding_dim = text_field.vocab.vectors.size()[-1]
7475
args.vectors = text_field.vocab.vectors
76+
if args.multichannel:
77+
args.static = True
78+
args.non_static = True
7579
args.class_num = len(label_field.vocab)
7680
args.cuda = args.device != -1 and torch.cuda.is_available()
7781
args.filter_sizes = [int(size) for size in args.filter_sizes.split(',')]

model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,22 @@ def __init__(self, args):
1818
self.embedding = nn.Embedding(vocabulary_size, embedding_dimension)
1919
if args.static:
2020
self.embedding = self.embedding.from_pretrained(args.vectors, freeze=not args.non_static)
21-
21+
if args.multichannel:
22+
self.embedding2 = nn.Embedding(vocabulary_size, embedding_dimension).from_pretrained(args.vectors)
23+
chanel_num += 1
24+
else:
25+
self.embedding2 = None
2226
self.convs = nn.ModuleList(
2327
[nn.Conv2d(chanel_num, filter_num, (size, embedding_dimension)) for size in filter_sizes])
2428
self.dropout = nn.Dropout(args.dropout)
2529
self.fc = nn.Linear(len(filter_sizes) * filter_num, class_num)
2630

2731
def forward(self, x):
28-
x = self.embedding(x)
29-
x = x.unsqueeze(1)
32+
if self.embedding2:
33+
x = torch.stack([self.embedding(x), self.embedding2(x)], dim=1)
34+
else:
35+
x = self.embedding(x)
36+
x = x.unsqueeze(1)
3037
x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
3138
x = [F.max_pool1d(item, item.size(2)).squeeze(2) for item in x]
3239
x = torch.cat(x, 1)

0 commit comments

Comments
 (0)