Skip to content

Commit e8e0307

Browse files
author
higgs
committed
Implement the model and train procedure.
1 parent 402d2db commit e8e0307

File tree

5 files changed

+10964
-0
lines changed

5 files changed

+10964
-0
lines changed

data_utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import numpy as np
2+
import re
3+
from collections import defaultdict
4+
5+
def clean_str(string):
    r"""
    Tokenization/string cleaning for all datasets except for SST.

    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Lowercases the text, replaces characters outside [A-Za-z0-9(),!?'`]
    with spaces, splits common English contractions ("n't", "'ve", ...)
    into separate tokens, and pads punctuation with spaces.

    NOTE: "(", ")" and "?" become the literal tokens "\(", "\)" and "\?"
    because re.sub's template parser keeps the backslash of an unknown
    non-letter escape. This quirk is preserved intentionally so the
    produced vocabulary matches the original preprocessing.

    Args:
        string: raw input sentence.

    Returns:
        The cleaned, lowercased, single-spaced string.
    """
    # Raw strings throughout: "\(" inside a non-raw literal is an invalid
    # escape sequence (SyntaxWarning since Python 3.12, a future
    # SyntaxError). The raw forms produce byte-identical replacement
    # templates, so runtime behavior is unchanged.
    string = re.sub(r"[^A-Za-z0-9(),!?'`]", " ", string)
    string = re.sub(r"'s", r" 's", string)
    string = re.sub(r"'ve", r" 've", string)
    string = re.sub(r"n't", r" n't", string)
    string = re.sub(r"'re", r" 're", string)
    string = re.sub(r"'d", r" 'd", string)
    string = re.sub(r"'ll", r" 'll", string)
    string = re.sub(r",", r" , ", string)
    string = re.sub(r"!", r" ! ", string)
    string = re.sub(r"\(", r" \( ", string)
    string = re.sub(r"\)", r" \) ", string)
    string = re.sub(r"\?", r" \? ", string)
    # Collapse any run of whitespace introduced above to a single space.
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
24+
25+
def load_data(pos_path, neg_path):
    """
    Load positive and negative example sentences and build a vocabulary.

    Each line of the input files is one example. Lines are cleaned with
    clean_str(); every *distinct* word in a sentence increments its vocab
    count by one, so counts are document frequencies, not term
    frequencies.

    Args:
        pos_path: path to the file of positive examples (label (1, 0)).
        neg_path: path to the file of negative examples (label (0, 1)).

    Returns:
        A tuple (data, vocab) where data is a list of
        (sentence, one_hot_label) pairs -- all positives first, then all
        negatives -- and vocab maps word -> number of sentences
        containing it.
    """
    sents = []
    labels = []
    vocab = defaultdict(int)
    # One parsing loop shared by both files; the original duplicated the
    # entire loop body per file.
    for path, label in ((pos_path, (1, 0)), (neg_path, (0, 1))):
        with open(path, "r") as f:
            for line in f:
                s = clean_str(line.strip())
                # set(): count each word at most once per sentence.
                for w in set(s.split()):
                    vocab[w] += 1
                sents.append(s)
                labels.append(label)
    return list(zip(sents, labels)), vocab
47+
48+
def get_word_idx_map(vocab):
    """
    Build a word -> integer index mapping over the vocabulary.

    Index 0 is reserved for the "<PAD>" token; vocabulary words are
    numbered from 1 in the iteration order of *vocab*.

    Args:
        vocab: iterable of vocabulary words (e.g. the dict returned by
            load_data).

    Returns:
        dict mapping each word (plus "<PAD>") to a unique index.
    """
    word_idx_map = {"<PAD>": 0}
    # enumerate(start=1) replaces the original hand-rolled counter;
    # 0 stays reserved for padding.
    for i, w in enumerate(vocab, start=1):
        word_idx_map[w] = i
    return word_idx_map
55+
56+
def fetch_batch(data, batch_index, batch_size):
    """
    Return the batch_index-th batch of size batch_size from *data*.

    The final batch may be shorter than batch_size; a batch_index past
    the end of the data yields an empty slice.

    Args:
        data: indexable sequence of examples.
        batch_index: 0-based batch number.
        batch_size: maximum number of examples per batch.

    Returns:
        The slice data[batch_index*batch_size : (batch_index+1)*batch_size].
    """
    # Slicing already clamps to len(data), so the original's explicit
    # min() and its commented-out num_batches computation are unnecessary.
    beg = batch_index * batch_size
    return data[beg : beg + batch_size]
62+

0 commit comments

Comments
 (0)