Skip to content

Commit c957bc6

Browse files
modify init and preprocess logic
1 parent a793973 commit c957bc6

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

python/dnlp/data_process/process_cws.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
import re
33
import pickle
44
from dnlp.data_process.processor import Preprocessor
5-
from dnlp.utils.constant import TAG_BEGIN, TAG_INSIDE, TAG_END, TAG_SINGLE,CWS_TAGS
5+
from dnlp.utils.constant import TAG_BEGIN, TAG_INSIDE, TAG_END, TAG_SINGLE,CWS_TAGS,UNK_VAL
66

77

88
class ProcessCWS(Preprocessor):
99
def __init__(self, *, files: tuple = (), dict_path: str = '', base_folder: str = 'dnlp/data', name: str = '',
10-
delimiter: tuple = ('。')):
10+
mode:str='train',delimiter: tuple = ('。')):
11+
self.mode = mode
1112
self.SPLIT_CHAR = ' '
1213
if base_folder == '':
1314
raise Exception('base folder is empty')
@@ -53,17 +54,25 @@ def map_to_indices(self):
5354
lls = []
5455
for word in words:
5556
if len(word) == 1:
56-
chs.append(self.dictionary[word])
57+
if self.mode == 'train':
58+
chs.append(self.dictionary[word] if self.dictionary.get(word) is not None else UNK_VAL)
59+
else:
60+
chs.append(word)
5761
lls.append(TAG_SINGLE)
5862
elif len(word) == 0:
5963
raise Exception('word length is zero')
6064
else:
61-
chs.extend(map(lambda ch: self.dictionary[ch], word))
65+
if self.mode == 'train':
66+
chs.extend(map(lambda ch: self.dictionary[ch] if self.dictionary.get(ch) is not None else UNK_VAL, word))
67+
else:
68+
chs.append(word)
6269
lls.append(TAG_BEGIN)
6370
lls.extend([TAG_INSIDE] * (len(word) - 2))
6471
lls.append(TAG_END)
6572
characters.append(chs)
6673
labels.append(lls)
74+
if self.mode == 'test':
75+
characters = list(map(lambda words:''.join(words),characters))
6776
return characters, labels
6877

6978
def save_data(self):

python/scripts/init_datasets.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@ def copy():
99
dst_base_folder = '../dnlp/data/cws/'
1010
if not os.path.exists(dst_base_folder):
1111
os.makedirs(dst_base_folder)
12-
pku = 'pku_training.utf8'
13-
copyfile(src_folder + pku, dst_base_folder + pku)
12+
files = ['pku_training.utf8','pku_test.utf8']
13+
for f in files:
14+
copyfile(src_folder + f, dst_base_folder + f)
1415

1516

1617
def build_cws_datasets():
17-
files = ('pku_training.utf8',)
1818
base_folder = '../dnlp/data/cws/'
1919
if not os.path.exists(base_folder):
2020
os.makedirs(base_folder)
21-
ProcessCWS(files=files, base_folder=base_folder, name='pku_training')
22-
21+
ProcessCWS(files=('pku_training.utf8',), base_folder=base_folder, name='pku_training')
22+
dict_path = base_folder + 'pku_training_dict.utf8'
23+
ProcessCWS(files=('pku_test.utf8',), dict_path=dict_path,base_folder=base_folder, name='pku_test',mode='test')
2324

2425
if __name__ == '__main__':
2526
copy()

0 commit comments

Comments
 (0)