|
1 |
| -# #!/usr/bin/env python3 |
| 1 | +#!/usr/bin/env python3 |
2 | 2 |
|
3 |
| -# import time |
4 |
| -# import unittest |
5 |
| -# import random |
6 |
| -# from chdb import session |
| 3 | +import time |
| 4 | +import unittest |
| 5 | +import random |
| 6 | +from chdb import session |
7 | 7 |
|
8 |
| -# chs = session.Session() |
| 8 | +chs = None |
9 | 9 |
|
10 | 10 |
|
11 |
| -# class TestInsertArray(unittest.TestCase): |
12 |
| -# def setUp(self) -> None: |
13 |
| -# def generate_embedding(): |
14 |
| -# embedding = [random.uniform(-1, 1) for _ in range(16)] |
15 |
| -# return f'"{",".join(str(e) for e in embedding)}"' # format: "[1.0,2.0,3.0,...]" |
| 11 | +class TestInsertArray(unittest.TestCase): |
| 12 | + def setUp(self) -> None: |
| 13 | + def generate_embedding(): |
| 14 | + embedding = [random.uniform(-1, 1) for _ in range(16)] |
| 15 | + return f'"{",".join(str(e) for e in embedding)}"' # format: "1.0,2.0,3.0,..." (quoted, comma-separated, no square brackets) |
16 | 16 |
|
17 |
| -# with open("data.csv", "w", encoding="utf-8") as file: |
18 |
| -# for movieId in range(1, 100001): |
19 |
| -# embedding = generate_embedding() |
20 |
| -# line = f"{movieId},{embedding}\n" |
21 |
| -# file.write(line) |
| 17 | + with open("data.csv", "w", encoding="utf-8") as file: |
| 18 | + for movieId in range(1, 100001): |
| 19 | + embedding = generate_embedding() |
| 20 | + line = f"{movieId},{embedding}\n" |
| 21 | + file.write(line) |
22 | 22 |
|
23 |
| -# return super().setUp() |
| 23 | + return super().setUp() |
24 | 24 |
|
25 |
| -# def tearDown(self) -> None: |
26 |
| -# return super().tearDown() |
| 25 | + def tearDown(self) -> None: |
| 26 | + return super().tearDown() |
27 | 27 |
|
28 |
| -# def test_01_insert_array(self): |
29 |
| -# chs.query("CREATE DATABASE IF NOT EXISTS movie_embeddings ENGINE = Atomic") |
30 |
| -# chs.query("USE movie_embeddings") |
31 |
| -# chs.query("DROP TABLE IF EXISTS embeddings") |
32 |
| -# chs.query("DROP TABLE IF EXISTS embeddings_with_title") |
| 28 | + def test_01_insert_array(self): |
| 29 | + global chs |
| 30 | + chs = session.Session() |
| 31 | + chs.query("CREATE DATABASE IF NOT EXISTS movie_embeddings ENGINE = Atomic") |
| 32 | + chs.query("USE movie_embeddings") |
| 33 | + chs.query("DROP TABLE IF EXISTS embeddings") |
| 34 | + chs.query("DROP TABLE IF EXISTS embeddings_with_title") |
33 | 35 |
|
34 |
| -# chs.query( |
35 |
| -# """CREATE TABLE embeddings ( |
36 |
| -# movieId UInt32 NOT NULL, |
37 |
| -# embedding Array(Float32) NOT NULL |
38 |
| -# ) ENGINE = MergeTree() |
39 |
| -# ORDER BY movieId""" |
40 |
| -# ) |
| 36 | + chs.query( |
| 37 | + """CREATE TABLE embeddings ( |
| 38 | + movieId UInt32 NOT NULL, |
| 39 | + embedding Array(Float32) NOT NULL |
| 40 | + ) ENGINE = MergeTree() |
| 41 | + ORDER BY movieId""" |
| 42 | + ) |
41 | 43 |
|
42 |
| -# print("Inserting movie embeddings into the database") |
43 |
| -# t0 = time.time() |
44 |
| -# print(chs.query("INSERT INTO embeddings FROM INFILE 'data.csv' FORMAT CSV")) |
45 |
| -# rows = chs.query("SELECT count(*) FROM embeddings") |
46 |
| -# print(f"Inserted {rows} rows in {time.time() - t0} seconds") |
| 44 | + print("Inserting movie embeddings into the database") |
| 45 | + t0 = time.time() |
| 46 | + print(chs.query("INSERT INTO embeddings FROM INFILE 'data.csv' FORMAT CSV")) |
| 47 | + rows = chs.query("SELECT count(*) FROM embeddings") |
| 48 | + print(f"Inserted {rows} rows in {time.time() - t0} seconds") |
47 | 49 |
|
48 |
| -# print("Select result:", chs.query("SELECT * FROM embeddings LIMIT 5")) |
| 50 | + print("Select result:", chs.query("SELECT * FROM embeddings LIMIT 5")) |
49 | 51 |
|
50 |
| -# def test_02_query_order_by_cosine_distance(self): |
51 |
| -# # You can change the 100 to any movieId you want, but that is just an example |
52 |
| -# # If you want to see a real world example, please check the |
53 |
| -# # `examples/chDB_vector_search.ipynb` |
54 |
| -# # the example is based on the MovieLens dataset and embeddings are generated |
55 |
| -# # by the Word2Vec algorithm just extract the movie similarity info from |
56 |
| -# # users' movie ratings without any extra data. |
57 |
| -# topN = chs.query( |
58 |
| -# """ |
59 |
| -# WITH |
60 |
| -# 100 AS theMovieId, |
61 |
| -# (SELECT embedding FROM embeddings WHERE movieId = theMovieId LIMIT 1) AS targetEmbedding |
62 |
| -# SELECT |
63 |
| -# movieId, |
64 |
| -# cosineDistance(embedding, targetEmbedding) AS distance |
65 |
| -# FROM embeddings |
66 |
| -# WHERE movieId != theMovieId |
67 |
| -# ORDER BY distance ASC |
68 |
| -# LIMIT 5 |
69 |
| -# """ |
70 |
| -# ) |
71 |
| -# print( |
72 |
| -# f"Scaned {topN.rows_read()} rows, " |
73 |
| -# f"Top 5 similar movies to movieId 100 in {topN.elapsed()}" |
74 |
| -# ) |
75 |
| -# print(topN) |
| 52 | + def test_02_query_order_by_cosine_distance(self): |
| 53 | + # The movieId 100 is only an example; you can change it to any movieId you want. |
| 54 | + # For a real-world example, please see |
| 55 | + # `examples/chDB_vector_search.ipynb`. |
| 56 | + # That example is based on the MovieLens dataset; its embeddings are generated |
| 57 | + # by the Word2Vec algorithm, which extracts movie-similarity information from |
| 58 | + # users' movie ratings alone, without any extra data. |
| 59 | + global chs |
| 60 | + topN = chs.query( |
| 61 | + """ |
| 62 | + WITH |
| 63 | + 100 AS theMovieId, |
| 64 | + (SELECT embedding FROM embeddings WHERE movieId = theMovieId LIMIT 1) AS targetEmbedding |
| 65 | + SELECT |
| 66 | + movieId, |
| 67 | + cosineDistance(embedding, targetEmbedding) AS distance |
| 68 | + FROM embeddings |
| 69 | + WHERE movieId != theMovieId |
| 70 | + ORDER BY distance ASC |
| 71 | + LIMIT 5 |
| 72 | + """ |
| 73 | + ) |
| 74 | + print( |
| 75 | + f"Scaned {topN.rows_read()} rows, " |
| 76 | + f"Top 5 similar movies to movieId 100 in {topN.elapsed()}" |
| 77 | + ) |
| 78 | + print(topN) |
76 | 79 |
|
77 |
| -# def test_03_close_session(self): |
78 |
| -# chs.close() |
79 |
| -# self.assertEqual(chs._conn, None) |
| 80 | + def test_03_close_session(self): |
| 81 | + global chs |
| 82 | + chs.close() |
| 83 | + self.assertEqual(chs._conn, None) |
80 | 84 |
|
81 | 85 |
|
82 |
| -# if __name__ == "__main__": |
83 |
| -# unittest.main() |
| 86 | +if __name__ == "__main__": |
| 87 | + unittest.main() |
0 commit comments