Skip to content

Commit f59a0d1

Browse files
committed
Syc codes.
1 parent 03e6fc1 commit f59a0d1

12 files changed

+1449
-118
lines changed

bound.ipynb

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stderr",
10+
"output_type": "stream",
11+
"text": [
12+
"# edit distance 1000x1000: 100%|██████████| 1000/1000 [02:04<00:00, 8.02it/s]\n",
13+
" 0%| | 0/52 [00:00<?, ?it/s]"
14+
]
15+
},
16+
{
17+
"name": "stdout",
18+
"output_type": "stream",
19+
"text": [
20+
"# Calculate edit distance time: 124.6652319431305\n"
21+
]
22+
},
23+
{
24+
"name": "stderr",
25+
"output_type": "stream",
26+
"text": [
27+
"100%|██████████| 52/52 [1:32:50<00:00, 91.40s/it] "
28+
]
29+
},
30+
{
31+
"name": "stdout",
32+
"output_type": "stream",
33+
"text": [
34+
"52.0 1.8483353884093712 28.439545176737834\n"
35+
]
36+
},
37+
{
38+
"name": "stderr",
39+
"output_type": "stream",
40+
"text": [
41+
"\n"
42+
]
43+
}
44+
],
45+
"source": [
46+
"import random\n",
47+
"import string\n",
48+
"import numpy as np\n",
49+
"from multiprocessing import cpu_count\n",
50+
"np.random.seed(1)\n",
51+
"random.seed(1)\n",
52+
"\n",
53+
"C = 52 \n",
54+
"M = 1000\n",
55+
"letters = list(range(C))\n",
56+
"\n",
57+
"def randomString(stringLength):\n",
58+
" \"\"\"Generate a random string of fixed length \"\"\"\n",
59+
" return [random.choice(letters) for _ in range(stringLength)]\n",
60+
"\n",
61+
"def int2str(l):\n",
62+
" return \"\".join(chr(i+ord('a')) for i in l)\n",
63+
"\n",
64+
"N = 1000\n",
65+
"strings = [randomString(random.randint(1, M)) for _ in range(N)]\n",
66+
"lengths = [len(i) for i in strings]\n",
67+
"def one_hot(s):\n",
68+
" encode = np.zeros((C, M), dtype=np.int)\n",
69+
" encode[np.array(s), np.arange(len(s))] = 1\n",
70+
" return encode\n",
71+
"\n",
72+
"oh_strs = [one_hot(s) for s in strings]\n",
73+
"or_strs = [int2str(s) for s in strings]\n",
74+
"\n",
75+
"from datasets import all_pair_distance\n",
76+
"knnd = all_pair_distance(or_strs, or_strs, cpu_count())\n",
77+
"\n",
78+
"oh_strs = np.array(oh_strs)\n",
79+
"\n",
80+
"import tqdm\n",
81+
"dist = []\n",
82+
"def int2str(s):\n",
83+
" return \"\".join(str(i) for i in s)\n",
84+
"for i in tqdm.tqdm(range(C)):\n",
85+
" ss = oh_strs[:, i, :]\n",
86+
" ss = [int2str(s[:lengths[i]]) for i, s in enumerate(ss)]\n",
87+
" d = all_pair_distance(ss, ss, 8, progress=False)\n",
88+
" dist.append(d)\n",
89+
"\n",
90+
"dist = np.array(dist)\n",
91+
"bound = np.sum(dist, axis=0)\n",
92+
"index = np.where(knnd != 0)\n",
93+
"ration = bound[index] / knnd[index]\n",
94+
"print(np.max(ration), np.min(ration), np.mean(ration))"
95+
]
96+
},
97+
{
98+
"cell_type": "code",
99+
"execution_count": 2,
100+
"metadata": {},
101+
"outputs": [
102+
{
103+
"name": "stdout",
104+
"output_type": "stream",
105+
"text": [
106+
"(1000, 1000)\n",
107+
"(1000, 1000)\n",
108+
"(array([], dtype=int64), array([], dtype=int64))\n"
109+
]
110+
},
111+
{
112+
"name": "stderr",
113+
"output_type": "stream",
114+
"text": [
115+
"/home/xinyan/.conda/envs/py3/lib/python3.6/site-packages/ipykernel_launcher.py:3: RuntimeWarning: invalid value encountered in true_divide\n",
116+
" This is separate from the ipykernel package so we can avoid doing imports until\n"
117+
]
118+
}
119+
],
120+
"source": [
121+
"print(knnd.shape)\n",
122+
"print(bound.shape)\n",
123+
"idx = np.where( bound/knnd == 1.6344086021505377)\n",
124+
"print(idx)"
125+
]
126+
},
127+
{
128+
"cell_type": "code",
129+
"execution_count": 3,
130+
"metadata": {},
131+
"outputs": [],
132+
"source": [
133+
"pairs = list((or_strs[i], or_strs[j]) for i, j in zip(idx[0], idx[1]))"
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": 4,
139+
"metadata": {},
140+
"outputs": [
141+
{
142+
"data": {
143+
"text/plain": [
144+
"[]"
145+
]
146+
},
147+
"execution_count": 4,
148+
"metadata": {},
149+
"output_type": "execute_result"
150+
}
151+
],
152+
"source": [
153+
"pairs"
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": null,
159+
"metadata": {},
160+
"outputs": [],
161+
"source": []
162+
},
163+
{
164+
"cell_type": "code",
165+
"execution_count": null,
166+
"metadata": {},
167+
"outputs": [],
168+
"source": []
169+
},
170+
{
171+
"cell_type": "code",
172+
"execution_count": null,
173+
"metadata": {},
174+
"outputs": [],
175+
"source": []
176+
}
177+
],
178+
"metadata": {
179+
"kernelspec": {
180+
"display_name": "Python 3",
181+
"language": "python",
182+
"name": "python3"
183+
},
184+
"language_info": {
185+
"codemirror_mode": {
186+
"name": "ipython",
187+
"version": 3
188+
},
189+
"file_extension": ".py",
190+
"mimetype": "text/x-python",
191+
"name": "python",
192+
"nbconvert_exporter": "python",
193+
"pygments_lexer": "ipython3",
194+
"version": "3.6.8"
195+
}
196+
},
197+
"nbformat": 4,
198+
"nbformat_minor": 1
199+
}

datasets.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,20 @@ def f(x):
1515
return [Levenshtein.distance(a, b) for b in B]
1616

1717

18-
def all_pair_distance(A, B, n_thread):
18+
def all_pair_distance(A, B, n_thread, progress=True):
19+
bar = tqdm if progress else lambda iterable,total,desc : iterable
1920
def all_pair(A, B, n_thread):
2021
with Pool(n_thread) as pool:
2122
start_time = time.time()
2223
edit = list(
23-
tqdm(
24+
bar(
2425
pool.imap(f, zip(A, [B for _ in A])),
2526
total=len(A),
2627
desc="# edit distance {}x{}".format(len(A), len(B)),
2728
)
2829
)
29-
print("# Calculate edit distance time: {}".format(time.time() - start_time))
30+
if progress:
31+
print("# Calculate edit distance time: {}".format(time.time() - start_time))
3032
return np.array(edit)
3133

3234
if len(A) < len(B):
@@ -62,17 +64,20 @@ def word2sig(lines, max_length=None):
6264

6365
all_chars = dict()
6466
all_chars["counter"] = 0
67+
alphabet = ''
6568

6669
def to_ord(c):
6770
nonlocal all_chars
71+
nonlocal alphabet
6872
if not (c in all_chars):
73+
alphabet += c
6974
all_chars[c] = all_chars["counter"]
7075
all_chars["counter"] = all_chars["counter"] + 1
7176
return all_chars[c]
7277

7378
x = [[to_ord(c) for c in line] for line in lines]
7479

75-
return all_chars["counter"], max_length, x
80+
return all_chars["counter"], max_length, x, alphabet
7681

7782

7883
def ivecs_read(file):

distance_estimation_cgk.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
import numpy as np
3+
from main import get_args
4+
from nns import linear_fit
5+
from embed_cgk import random_seed, cgk_string, distance
6+
7+
8+
threshold = 1000
9+
10+
args, data_handler, data_file = get_args()
11+
train_dist, query_dist = data_handler.train_dist, data_handler.query_dist
12+
train_idx = np.where(train_dist < threshold)
13+
query_idx = np.where(query_dist < threshold)
14+
15+
dis_dir = "cgk_dist/{}".format(args.dataset)
16+
os.makedirs(dis_dir, exist_ok=True)
17+
if not os.path.isfile(dis_dir + "train_idx.npy"):
18+
h = random_seed(data_handler.M, data_handler.C)
19+
xq = cgk_string(h, data_handler.xq.sig, data_handler.M)
20+
xt = cgk_string(h, data_handler.xt.sig, data_handler.M)
21+
xb = cgk_string(h, data_handler.xb.sig, data_handler.M)
22+
23+
train_dist_hm = distance(xt, xt)
24+
query_dist_hm = distance(xq, xb)
25+
26+
np.save(dis_dir + "train_dist_hm.npy", train_dist_hm)
27+
np.save(dis_dir + "query_dist_hm.npy", query_dist_hm)
28+
else:
29+
train_dist_hm = np.load(dis_dir + "train_dist_hm.npy")
30+
query_dist_hm = np.load(dis_dir + "query_dist_hm.npy")
31+
32+
l2ed_gru = linear_fit(
33+
train_dist_hm[train_idx],
34+
train_dist[train_idx], deg=2)
35+
print(np.mean(np.abs(l2ed_gru(query_dist_hm[query_idx]) / query_dist[query_idx] - 1.0)))

distance_estimation_l2.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import numpy as np
2+
from utils import l2_dist
3+
from nns import linear_fit, load_vec, get_args
4+
5+
6+
7+
threshold = 1000
8+
9+
args = get_args()
10+
11+
# args.embed = 'gru'
12+
# xq_gru, xb_gru, xt_gru, train_dist, query_dist = load_vec(args)
13+
# query_dist = query_dist[:, :50000]
14+
# xb_gru = xb_gru[:50000, :]
15+
#
16+
# train_idx = np.where(train_dist < threshold)
17+
# query_idx = np.where(query_dist < threshold)
18+
#
19+
# train_dist_l2_gru = l2_dist(xt_gru, xt_gru)
20+
#
21+
# l2ed_gru = linear_fit(
22+
# train_dist_l2_gru[train_idx],
23+
# train_dist[train_idx], deg=1)
24+
#
25+
# query_dist_l2_gru = l2_dist(xq_gru, xb_gru)
26+
# print(np.mean(np.abs(l2ed_gru(query_dist_l2_gru[query_idx]) / query_dist[query_idx] - 1.0)))
27+
28+
29+
# args.embed = 'cnn'
30+
# xq_cnn, xb_cnn, xt_cnn,train_dist, query_dist = load_vec(args)
31+
# query_dist = query_dist[:, :50000]
32+
# xb_cnn = xb_cnn[:50000, :]
33+
#
34+
# train_idx = np.where(train_dist < threshold)
35+
# query_idx = np.where(query_dist < threshold)
36+
#
37+
# print("# training all pair distance")
38+
# train_dist_l2_cnn = l2_dist(xt_cnn, xt_cnn)
39+
# print("# training all pair distance fitting to edit distance")
40+
# l2ed_cnn = linear_fit(
41+
# train_dist_l2_cnn[train_idx],
42+
# train_dist[train_idx],
43+
# deg=1)
44+
# print("# query all pair distance")
45+
# query_dist_l2_cnn = l2_dist(xq_cnn, xb_cnn)
46+
# print("# fitting errors")
47+
# print(np.mean(np.abs(l2ed_cnn(query_dist_l2_cnn)[query_idx] / query_dist[query_idx] - 1.0)))
48+
49+
50+
import matplotlib.pyplot as plt
51+
fontsize = 44
52+
ticksize = 40
53+
labelsize = 35
54+
legendsize = 30
55+
plt.style.use("seaborn-white")
56+
57+
W = 12.0
58+
H = 9.5
59+
def _plot_setting():
60+
plt.yticks(fontsize=ticksize)
61+
plt.xticks(fontsize=ticksize)
62+
plt.gcf().set_size_inches(W, H)
63+
plt.subplots_adjust(
64+
top=0.976,
65+
bottom=0.141,
66+
left=0.133,
67+
right=0.988,
68+
hspace=0.2,
69+
wspace=0.2
70+
)
71+
print("# plotting")
72+
# idx = np.random.choice(np.size(query_dist[query_idx]), threshold)
73+
74+
# plt.scatter(query_dist[query_idx].reshape(-1)[idx],
75+
# l2ed_gru(query_dist_l2_gru[query_idx].reshape(-1))[idx], color="blue")
76+
77+
# plt.scatter(query_dist[query_idx].reshape(-1)[idx],
78+
# l2ed_cnn(query_dist_l2_cnn[query_idx].reshape(-1))[idx], color="red")
79+
# plt.scatter(query_dist[query_idx].reshape(-1)[idx],
80+
# query_dist[query_idx].reshape(-1)[idx], color="black")
81+
82+
plt.xlim(left=-10, right=threshold)
83+
plt.ylim(bottom=-10, top=threshold)
84+
_plot_setting()
85+
plt.xlabel("True Edit Distance", fontsize=fontsize)
86+
plt.ylabel("Estimated Edit Distance", fontsize=fontsize)
87+
plt.text(
88+
x=0.02, y=0.8, color='blue',
89+
s=args.dataset.upper(),
90+
fontsize=labelsize,
91+
transform=plt.subplot().transAxes
92+
)
93+
# plt.savefig("/home/xinyan/Dropbox/project/Yan Xiao Paper/string-embedding/figures/distance_estimation_{}.pdf".format(args.dataset))
94+
plt.show()

embed_cnn.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@ def cnn_embedding(args, h, data_file):
5050
xq = _batch_embed(args, model.embedding_net, h.xq, device)
5151
print("# Embedding time: " + str(embed_time))
5252
if args.save_embed:
53-
np.save("{}/embedding_xb".format(data_file), xb)
54-
np.save("{}/embedding_xt".format(data_file), xt)
55-
np.save("{}/embedding_xq".format(data_file), xq)
53+
if args.embed_dir != "":
54+
args.embed_dir = args.embed_dir + "/"
55+
os.makedirs("{}/{}".format(data_file, args.embed_dir), exist_ok=True)
56+
np.save("{}/{}embedding_xb".format(data_file, args.embed_dir), xb)
57+
np.save("{}/{}embedding_xt".format(data_file, args.embed_dir), xt)
58+
np.save("{}/{}embedding_xq".format(data_file, args.embed_dir), xq)
5659

5760
if args.recall:
5861
test_recall(xb, xq, h.query_knn)

0 commit comments

Comments
 (0)