
Commit 3d037ad

Add dataset and update the files
1 parent e12e521 commit 3d037ad

24 files changed: +50157 -0 lines changed

Pipfile

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]
nltk = "*"
matplotlib = "*"

[requires]
python_version = "3.7"

[dev-packages]
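
For context (not part of the commit): nltk supplies the TweetTokenizer used by data/ptb.py below; matplotlib is presumably for plotting elsewhere in the repo. A minimal sketch of the tokenizer, assuming an environment created from this Pipfile (e.g. via `pipenv install`) and an illustrative sample sentence:

from nltk.tokenize import TweetTokenizer

# Same tokenizer configuration that data/ptb.py uses for preprocessing.
tokenizer = TweetTokenizer(preserve_case=False)
print(tokenizer.tokenize("The cat sat on the mat."))
# -> ['the', 'cat', 'sat', 'on', 'the', 'mat', '.']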

Pipfile.lock

Lines changed: 288 additions & 0 deletions
Some generated files are not rendered by default.

__pycache__/loss.cpython-37.pyc

1.21 KB
Binary file not shown.

__pycache__/model.cpython-37.pyc

3.86 KB
Binary file not shown.

__pycache__/settings.cpython-37.pyc

425 Bytes
Binary file not shown.

__pycache__/train.cpython-37.pyc

1.97 KB
Binary file not shown.

__pycache__/utils.cpython-37.pyc

2.72 KB
Binary file not shown.

data/__pycache__/ptb.cpython-37.pyc

5.06 KB
Binary file not shown.

data/__pycache__/utils.cpython-37.pyc

1.27 KB
Binary file not shown.

data/ptb.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
import os
import io
import json
import torch
import numpy as np
from collections import defaultdict
from torch.utils.data import Dataset
from nltk.tokenize import TweetTokenizer

from data.utils import OrderedCounter


# PTB dataset: tokenizes ptb.<split>.txt, pads/truncates each sentence to
# max_sequence_length, maps words to vocabulary ids, and caches the result as JSON.
class PTB(Dataset):

    def __init__(self, data_dir, split, create_data, **kwargs):

        super().__init__()
        self.data_dir = data_dir
        self.split = split
        self.max_sequence_length = kwargs.get('max_sequence_length', 50)
        self.min_occ = kwargs.get('min_occ', 3)

        self.raw_data_path = os.path.join(data_dir, 'ptb.'+split+'.txt')
        self.data_file = 'ptb.'+split+'.json'
        self.vocab_file = 'ptb.vocab.json'

        if create_data:
            print("Creating new %s ptb data."%split.upper())
            self._create_data()

        elif not os.path.exists(os.path.join(self.data_dir, self.data_file)):
            print("%s preprocessed file not found at %s. Creating new."%(split.upper(), os.path.join(self.data_dir, self.data_file)))
            self._create_data()

        else:
            self._load_data()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Keys are strings because the data dict is round-tripped through JSON.
        idx = str(idx)

        return {
            'input': np.asarray(self.data[idx]['input']),
            'target': np.asarray(self.data[idx]['target']),
            'length': self.data[idx]['length']
        }

    @property
    def vocab_size(self):
        return len(self.w2i)

    @property
    def pad_idx(self):
        return self.w2i['<pad>']

    @property
    def sos_idx(self):
        return self.w2i['<sos>']

    @property
    def eos_idx(self):
        return self.w2i['<eos>']

    @property
    def unk_idx(self):
        return self.w2i['<unk>']

    def get_w2i(self):
        return self.w2i

    def get_i2w(self):
        return self.i2w

    def _load_data(self, vocab=True):

        with open(os.path.join(self.data_dir, self.data_file), 'r') as file:
            self.data = json.load(file)
        if vocab:
            with open(os.path.join(self.data_dir, self.vocab_file), 'r') as file:
                vocab = json.load(file)
            self.w2i, self.i2w = vocab['w2i'], vocab['i2w']

    def _load_vocab(self):
        with open(os.path.join(self.data_dir, self.vocab_file), 'r') as vocab_file:
            vocab = json.load(vocab_file)

        self.w2i, self.i2w = vocab['w2i'], vocab['i2w']

    def _create_data(self):

        if self.split == 'train':
            self._create_vocab()
        else:
            self._load_vocab()

        tokenizer = TweetTokenizer(preserve_case=False)

        data = defaultdict(dict)
        with open(self.raw_data_path, 'r') as file:

            for i, line in enumerate(file):

                words = tokenizer.tokenize(line)

                # <sos>-prefixed input and <eos>-terminated target, truncated and
                # padded to max_sequence_length, then mapped to vocabulary ids.
                input = ['<sos>'] + words
                input = input[:self.max_sequence_length]

                target = words[:self.max_sequence_length-1]
                target = target + ['<eos>']

                assert len(input) == len(target), "%i, %i"%(len(input), len(target))
                length = len(input)

                input.extend(['<pad>'] * (self.max_sequence_length-length))
                target.extend(['<pad>'] * (self.max_sequence_length-length))

                input = [self.w2i.get(w, self.w2i['<unk>']) for w in input]
                target = [self.w2i.get(w, self.w2i['<unk>']) for w in target]

                id = len(data)
                data[id]['input'] = input
                data[id]['target'] = target
                data[id]['length'] = length

        with io.open(os.path.join(self.data_dir, self.data_file), 'wb') as data_file:
            data = json.dumps(data, ensure_ascii=False)
            data_file.write(data.encode('utf8', 'replace'))

        self._load_data(vocab=False)

    def _create_vocab(self):

        assert self.split == 'train', "Vocabulary can only be created for training file."

        tokenizer = TweetTokenizer(preserve_case=False)

        w2c = OrderedCounter()
        w2i = dict()
        i2w = dict()

        special_tokens = ['<pad>', '<unk>', '<sos>', '<eos>']
        for st in special_tokens:
            i2w[len(w2i)] = st
            w2i[st] = len(w2i)

        with open(self.raw_data_path, 'r') as file:

            for i, line in enumerate(file):
                words = tokenizer.tokenize(line)
                w2c.update(words)

            # Keep words occurring more than min_occ times, on top of the special tokens.
            for w, c in w2c.items():
                if c > self.min_occ and w not in special_tokens:
                    i2w[len(w2i)] = w
                    w2i[w] = len(w2i)

        assert len(w2i) == len(i2w)

        print("Vocabulary of %i keys created." %len(w2i))

        vocab = dict(w2i=w2i, i2w=i2w)
        with io.open(os.path.join(self.data_dir, self.vocab_file), 'wb') as vocab_file:
            data = json.dumps(vocab, ensure_ascii=False)
            vocab_file.write(data.encode('utf8', 'replace'))

        self._load_vocab()
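
For reference, a minimal usage sketch of the new PTB class (not part of the commit; the data directory, batch size, and create_data flag are illustrative assumptions — the class itself expects ptb.<split>.txt under data_dir):

from torch.utils.data import DataLoader
from data.ptb import PTB

# Build (or rebuild) the preprocessed train split; requires data/ptb.train.txt.
train_set = PTB(data_dir='data', split='train', create_data=True,
                max_sequence_length=50, min_occ=3)

loader = DataLoader(train_set, batch_size=32, shuffle=True)
batch = next(iter(loader))
print(batch['input'].shape)                     # e.g. torch.Size([32, 50]) of padded token ids
print(train_set.vocab_size, train_set.pad_idx)  # vocabulary size and <pad> id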

data/ptb.test.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
