3
3
4
4
from openai import AsyncAzureOpenAI , AsyncOpenAI
5
5
from openai .types .chat import ChatCompletionMessageParam
6
- from pydantic_ai import Agent , RunContext
6
+ from pydantic_ai import Agent
7
7
from pydantic_ai .messages import ModelMessagesTypeAdapter
8
8
from pydantic_ai .models .openai import OpenAIModel
9
9
from pydantic_ai .providers .openai import OpenAIProvider
10
10
from pydantic_ai .settings import ModelSettings
11
11
12
12
from fastapi_app .api_models import (
13
13
AIChatRoles ,
14
- BrandFilter ,
15
14
ChatRequestOverrides ,
16
15
Filter ,
17
16
ItemPublic ,
18
17
Message ,
19
- PriceFilter ,
20
18
RAGContext ,
21
19
RetrievalResponse ,
22
20
RetrievalResponseDelta ,
21
+ SearchArguments ,
23
22
SearchResults ,
24
23
ThoughtStep ,
25
24
)
@@ -59,7 +58,7 @@ def __init__(
59
58
),
60
59
system_prompt = self .query_prompt_template ,
61
60
tools = [self .search_database ],
62
- output_type = SearchResults ,
61
+ output_type = SearchArguments ,
63
62
)
64
63
self .answer_agent = Agent (
65
64
pydantic_chat_model ,
@@ -73,10 +72,7 @@ def __init__(
73
72
74
73
async def search_database (
75
74
self ,
76
- ctx : RunContext [ChatParams ],
77
- search_query : str ,
78
- price_filter : Optional [PriceFilter ] = None ,
79
- brand_filter : Optional [BrandFilter ] = None ,
75
+ search_arguments : SearchArguments ,
80
76
) -> SearchResults :
81
77
"""
82
78
Search PostgreSQL database for relevant products based on user query
@@ -91,52 +87,55 @@ async def search_database(
91
87
"""
92
88
# Only send non-None filters
93
89
filters : list [Filter ] = []
94
- if price_filter :
95
- filters .append (price_filter )
96
- if brand_filter :
97
- filters .append (brand_filter )
90
+ if search_arguments . price_filter :
91
+ filters .append (search_arguments . price_filter )
92
+ if search_arguments . brand_filter :
93
+ filters .append (search_arguments . brand_filter )
98
94
results = await self .searcher .search_and_embed (
99
- search_query ,
100
- top = ctx . deps .top ,
101
- enable_vector_search = ctx . deps .enable_vector_search ,
102
- enable_text_search = ctx . deps .enable_text_search ,
95
+ search_arguments . search_query ,
96
+ top = self . chat_params .top ,
97
+ enable_vector_search = self . chat_params .enable_vector_search ,
98
+ enable_text_search = self . chat_params .enable_text_search ,
103
99
filters = filters ,
104
100
)
105
101
return SearchResults (
106
- query = search_query , items = [ItemPublic .model_validate (item .to_dict ()) for item in results ], filters = filters
102
+ query = search_arguments .search_query ,
103
+ items = [ItemPublic .model_validate (item .to_dict ()) for item in results ],
104
+ filters = filters ,
107
105
)
108
106
109
107
async def prepare_context (self ) -> tuple [list [ItemPublic ], list [ThoughtStep ]]:
110
108
few_shots = ModelMessagesTypeAdapter .validate_json (self .query_fewshots )
111
109
user_query = f"Find search results for user query: { self .chat_params .original_user_query } "
112
- results = await self .search_agent .run (
110
+ search_agent_runner = await self .search_agent .run (
113
111
user_query ,
114
112
message_history = few_shots + self .chat_params .past_messages ,
115
- deps = self . chat_params ,
113
+ output_type = SearchArguments ,
116
114
)
117
- items = results .output .items
115
+ search_arguments = search_agent_runner .output
116
+ search_results = await self .search_database (search_arguments = search_arguments )
118
117
thoughts = [
119
118
ThoughtStep (
120
119
title = "Prompt to generate search arguments" ,
121
- description = results .all_messages (),
120
+ description = search_agent_runner .all_messages (),
122
121
props = self .model_for_thoughts ,
123
122
),
124
123
ThoughtStep (
125
124
title = "Search using generated search arguments" ,
126
- description = results . output .query ,
125
+ description = search_results .query ,
127
126
props = {
128
127
"top" : self .chat_params .top ,
129
128
"vector_search" : self .chat_params .enable_vector_search ,
130
129
"text_search" : self .chat_params .enable_text_search ,
131
- "filters" : results . output .filters ,
130
+ "filters" : search_results .filters ,
132
131
},
133
132
),
134
133
ThoughtStep (
135
134
title = "Search results" ,
136
- description = items ,
135
+ description = search_results . items ,
137
136
),
138
137
]
139
- return items , thoughts
138
+ return search_results . items , thoughts
140
139
141
140
async def answer (
142
141
self ,
0 commit comments