# make_training_examples.py
# (from a repository forked from crux82/ganbert)
# Author: Utkarsh Patel
#
# Prepares labeled.tsv and unlabeled.tsv from a JSON-lines file of comments.
from Preprocess import Preprocess
import pandas as pd
import numpy as np
import argparse
import os
from tqdm import tqdm
def _write_examples(path, labels, comments, preprocessor):
    """Write a header plus one ``<label> <tokens...>`` line per comment to *path*.

    Args:
        path: Destination file path; an existing file is overwritten.
        labels: Sequence of label strings, parallel to ``comments``.
        comments: Sequence of raw comment texts.
        preprocessor: Object whose ``preprocess(text)`` result can be passed
            to ``' '.join`` (presumably a list of string tokens -- confirm
            against the ``Preprocess`` class).
    """
    # Explicit UTF-8 keeps the output independent of the platform's default
    # locale encoding; the context manager closes the file even if
    # preprocessing or writing raises part-way through.
    with open(path, 'w', encoding='utf-8') as writer:
        writer.write('label comments\n')
        rows = tqdm(
            zip(labels, comments),
            total=len(comments),  # zip() has no len(); keep the progress bar bounded
            unit=" comments",
            desc="comments processed",
        )
        for label, comment in rows:
            tokens = preprocessor.preprocess(comment)
            writer.write(label + ' ' + ' '.join(tokens) + '\n')


def main() -> None:
    """Split the input comments into labeled and unlabeled example files.

    Reads a JSON-lines file (``--indir``) with ``body`` and ``violated_rule``
    columns. The first ``--fole`` fraction of the rows is written to
    ``labeled.tsv`` with label ``AH`` when ``violated_rule == 2`` and ``NONE``
    otherwise; the following ``--foule`` fraction is written to
    ``unlabeled.tsv`` with label ``UNK``. Both files are created in
    ``--outdir`` and use a single space between the label and the text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--fole", default=None, type=float, required=True,
                        help="ratio of training set to be used as labeled examples")
    parser.add_argument("--foule", default=None, type=float, required=True,
                        help="ratio of training set to be used as unlabeled examples")
    # The value is handed directly to pandas.read_json, so it must be a single
    # JSON-lines file rather than a directory.
    parser.add_argument("--indir", default=None, type=str, required=True,
                        help="path to the JSON-lines file containing comments")
    parser.add_argument("--outdir", default=None, type=str, required=True,
                        help="directory to store labeled.tsv and unlabeled.tsv files")
    args = parser.parse_args()

    pr = Preprocess()
    reader = pd.read_json(args.indir, lines=True, compression=None)
    comments = list(reader['body'])
    violated_rule = list(reader['violated_rule'])

    # Compute the split boundaries once: [0, labeled_end) is labeled,
    # [labeled_end, unlabeled_end) is unlabeled.
    total = len(comments)
    labeled_end = int(args.fole * total)
    unlabeled_end = int((args.fole + args.foule) * total)

    labeled_comments = comments[:labeled_end]
    unlabeled_comments = comments[labeled_end:unlabeled_end]
    print(f'Number of labeled examples : {len(labeled_comments)}')
    print(f'Number of unlabeled examples: {len(unlabeled_comments)}')

    # Rule value 2 is the one mapped to the 'AH' class; everything else is 'NONE'.
    labeled_labels = ['AH' if rule == 2 else 'NONE'
                      for rule in violated_rule[:labeled_end]]
    labeled_ah = labeled_labels.count('AH')
    labeled_none = len(labeled_labels) - labeled_ah

    # Create the output directory up front so a missing directory does not
    # abort the run after the input has already been loaded.
    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)

    print('Preparing labeled.tsv...')
    _write_examples(os.path.join(args.outdir, 'labeled.tsv'),
                    labeled_labels, labeled_comments, pr)

    print('Preparing unlabeled.tsv...')
    _write_examples(os.path.join(args.outdir, 'unlabeled.tsv'),
                    ['UNK'] * len(unlabeled_comments), unlabeled_comments, pr)

    print(f'Training set: AH: {labeled_ah} - NONE: {labeled_none}')
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()