Skip to content

Commit 598b6e8

Browse files
committed
Enable test_insert_vector
1 parent d814283 commit 598b6e8

File tree

1 file changed

+72
-68
lines changed

1 file changed

+72
-68
lines changed

tests/test_insert_vector.py

Lines changed: 72 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -1,83 +1,87 @@
1-
# #!/usr/bin/env python3
1+
#!/usr/bin/env python3
22

3-
# import time
4-
# import unittest
5-
# import random
6-
# from chdb import session
3+
import time
4+
import unittest
5+
import random
6+
from chdb import session
77

8-
# chs = session.Session()
8+
chs = None
99

1010

11-
# class TestInsertArray(unittest.TestCase):
12-
# def setUp(self) -> None:
13-
# def generate_embedding():
14-
# embedding = [random.uniform(-1, 1) for _ in range(16)]
15-
# return f'"{",".join(str(e) for e in embedding)}"' # format: "[1.0,2.0,3.0,...]"
11+
class TestInsertArray(unittest.TestCase):
12+
def setUp(self) -> None:
13+
def generate_embedding():
14+
embedding = [random.uniform(-1, 1) for _ in range(16)]
15+
return f'"{",".join(str(e) for e in embedding)}"' # format: "[1.0,2.0,3.0,...]"
1616

17-
# with open("data.csv", "w", encoding="utf-8") as file:
18-
# for movieId in range(1, 100001):
19-
# embedding = generate_embedding()
20-
# line = f"{movieId},{embedding}\n"
21-
# file.write(line)
17+
with open("data.csv", "w", encoding="utf-8") as file:
18+
for movieId in range(1, 100001):
19+
embedding = generate_embedding()
20+
line = f"{movieId},{embedding}\n"
21+
file.write(line)
2222

23-
# return super().setUp()
23+
return super().setUp()
2424

25-
# def tearDown(self) -> None:
26-
# return super().tearDown()
25+
def tearDown(self) -> None:
26+
return super().tearDown()
2727

28-
# def test_01_insert_array(self):
29-
# chs.query("CREATE DATABASE IF NOT EXISTS movie_embeddings ENGINE = Atomic")
30-
# chs.query("USE movie_embeddings")
31-
# chs.query("DROP TABLE IF EXISTS embeddings")
32-
# chs.query("DROP TABLE IF EXISTS embeddings_with_title")
28+
def test_01_insert_array(self):
29+
global chs
30+
chs = session.Session()
31+
chs.query("CREATE DATABASE IF NOT EXISTS movie_embeddings ENGINE = Atomic")
32+
chs.query("USE movie_embeddings")
33+
chs.query("DROP TABLE IF EXISTS embeddings")
34+
chs.query("DROP TABLE IF EXISTS embeddings_with_title")
3335

34-
# chs.query(
35-
# """CREATE TABLE embeddings (
36-
# movieId UInt32 NOT NULL,
37-
# embedding Array(Float32) NOT NULL
38-
# ) ENGINE = MergeTree()
39-
# ORDER BY movieId"""
40-
# )
36+
chs.query(
37+
"""CREATE TABLE embeddings (
38+
movieId UInt32 NOT NULL,
39+
embedding Array(Float32) NOT NULL
40+
) ENGINE = MergeTree()
41+
ORDER BY movieId"""
42+
)
4143

42-
# print("Inserting movie embeddings into the database")
43-
# t0 = time.time()
44-
# print(chs.query("INSERT INTO embeddings FROM INFILE 'data.csv' FORMAT CSV"))
45-
# rows = chs.query("SELECT count(*) FROM embeddings")
46-
# print(f"Inserted {rows} rows in {time.time() - t0} seconds")
44+
print("Inserting movie embeddings into the database")
45+
t0 = time.time()
46+
print(chs.query("INSERT INTO embeddings FROM INFILE 'data.csv' FORMAT CSV"))
47+
rows = chs.query("SELECT count(*) FROM embeddings")
48+
print(f"Inserted {rows} rows in {time.time() - t0} seconds")
4749

48-
# print("Select result:", chs.query("SELECT * FROM embeddings LIMIT 5"))
50+
print("Select result:", chs.query("SELECT * FROM embeddings LIMIT 5"))
4951

50-
# def test_02_query_order_by_cosine_distance(self):
51-
# # You can change the 100 to any movieId you want, but that is just an example
52-
# # If you want to see a real world example, please check the
53-
# # `examples/chDB_vector_search.ipynb`
54-
# # the example is based on the MovieLens dataset and embeddings are generated
55-
# # by the Word2Vec algorithm just extract the movie similarity info from
56-
# # users' movie ratings without any extra data.
57-
# topN = chs.query(
58-
# """
59-
# WITH
60-
# 100 AS theMovieId,
61-
# (SELECT embedding FROM embeddings WHERE movieId = theMovieId LIMIT 1) AS targetEmbedding
62-
# SELECT
63-
# movieId,
64-
# cosineDistance(embedding, targetEmbedding) AS distance
65-
# FROM embeddings
66-
# WHERE movieId != theMovieId
67-
# ORDER BY distance ASC
68-
# LIMIT 5
69-
# """
70-
# )
71-
# print(
72-
# f"Scaned {topN.rows_read()} rows, "
73-
# f"Top 5 similar movies to movieId 100 in {topN.elapsed()}"
74-
# )
75-
# print(topN)
52+
def test_02_query_order_by_cosine_distance(self):
53+
# The movieId 100 below is just an example; you can change it to any movieId you want.
54+
# If you want to see a real-world example, please check
55+
# `examples/chDB_vector_search.ipynb`.
56+
# That example is based on the MovieLens dataset; its embeddings are generated
57+
# by the Word2Vec algorithm, which extracts movie-similarity information from
58+
# users' movie ratings alone, without any extra data.
59+
global chs
60+
topN = chs.query(
61+
"""
62+
WITH
63+
100 AS theMovieId,
64+
(SELECT embedding FROM embeddings WHERE movieId = theMovieId LIMIT 1) AS targetEmbedding
65+
SELECT
66+
movieId,
67+
cosineDistance(embedding, targetEmbedding) AS distance
68+
FROM embeddings
69+
WHERE movieId != theMovieId
70+
ORDER BY distance ASC
71+
LIMIT 5
72+
"""
73+
)
74+
print(
75+
f"Scaned {topN.rows_read()} rows, "
76+
f"Top 5 similar movies to movieId 100 in {topN.elapsed()}"
77+
)
78+
print(topN)
7679

77-
# def test_03_close_session(self):
78-
# chs.close()
79-
# self.assertEqual(chs._conn, None)
80+
def test_03_close_session(self):
81+
global chs
82+
chs.close()
83+
self.assertEqual(chs._conn, None)
8084

8185

82-
# if __name__ == "__main__":
83-
# unittest.main()
86+
if __name__ == "__main__":
87+
unittest.main()

0 commit comments

Comments
 (0)