1
+ import numpy as np
2
+ import re
3
+ from collections import defaultdict
4
+
5
def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.

    Original taken from
    https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Strips every character outside [A-Za-z0-9(),!?'`], splits English
    contractions ('s, 've, n't, 're, 'd, 'll) into separate tokens, pads
    the punctuation , ! ( ) ? with spaces, collapses whitespace runs, and
    lowercases the result.
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " 's", string)
    string = re.sub(r"\'ve", " 've", string)
    string = re.sub(r"n\'t", " n't", string)
    string = re.sub(r"\'re", " 're", string)
    string = re.sub(r"\'d", " 'd", string)
    string = re.sub(r"\'ll", " 'll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    # Replacement templates must not escape the parens/question mark:
    # in a repl string "\(" is kept literally as the two chars "\(" (and the
    # non-raw literal is a deprecated invalid escape), so the original code
    # emitted tokens with a stray backslash. Plain "(", ")", "?" is intended.
    string = re.sub(r"\(", " ( ", string)
    string = re.sub(r"\)", " ) ", string)
    string = re.sub(r"\?", " ? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
24
+
25
def load_data(pos_path, neg_path):
    """
    Load positive and negative example sentences and build a vocabulary.

    Each line of ``pos_path`` / ``neg_path`` is one example. Lines are
    cleaned with ``clean_str()``. Labels are one-hot tuples: ``(1, 0)``
    for positive examples, ``(0, 1)`` for negative ones.

    Returns:
        A pair ``(data, vocab)`` where ``data`` is a list of
        ``(sentence, label)`` tuples (all positives first, then all
        negatives) and ``vocab`` maps word -> number of sentences the word
        occurs in (document frequency: duplicates within one sentence
        count once, because the words are deduplicated with ``set``).
    """
    sents = []
    labels = []
    vocab = defaultdict(int)

    def _ingest(path, label):
        # Shared per-file loop: clean each line, record sentence + label,
        # and bump the document frequency of each distinct word.
        # NOTE(review): encoding is unspecified, so decoding depends on the
        # locale default — confirm the corpus encoding and pass it explicitly.
        with open(path, "r") as data_file:
            for line in data_file:
                sent = clean_str(line.strip())
                for word in set(sent.split()):
                    vocab[word] += 1
                sents.append(sent)
                labels.append(label)

    _ingest(pos_path, (1, 0))
    _ingest(neg_path, (0, 1))
    return list(zip(sents, labels)), vocab
47
+
48
def get_word_idx_map(vocab):
    """
    Build a word -> integer-index mapping from a vocabulary.

    Index 0 is reserved for the "<PAD>" token; the remaining words are
    numbered 1..len(vocab) following the vocabulary's iteration order.
    """
    indexed_words = {word: idx for idx, word in enumerate(vocab, start=1)}
    return {"<PAD>": 0, **indexed_words}
55
+
56
def fetch_batch(data, batch_index, batch_size):
    """
    Return the ``batch_index``-th chunk of ``data``.

    Chunks are ``batch_size`` items each; the final chunk may be shorter,
    and an out-of-range index yields an empty slice.
    """
    start = batch_index * batch_size
    stop = min(start + batch_size, len(data))
    return data[start:stop]
62
+
0 commit comments