Skip to content

Commit fcc9dd2

Browse files
committed
Change to output type
1 parent 7e744d2 commit fcc9dd2

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed

src/backend/fastapi_app/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def create_app(testing: bool = False):
5858
else:
5959
if not testing:
6060
load_dotenv(override=True)
61-
logging.basicConfig(level=logging.INFO)
61+
logging.basicConfig(level=logging.DEBUG)
6262

6363
# Turn off particularly noisy INFO level logs from Azure Core SDK:
6464
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)

src/backend/fastapi_app/api_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ class BrandFilter(Filter):
117117
value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')")
118118

119119

120+
class SearchArguments(BaseModel):
121+
search_query: str
122+
price_filter: Optional[PriceFilter] = Field(default=None)
123+
brand_filter: Optional[BrandFilter] = Field(default=None)
124+
125+
120126
class SearchResults(BaseModel):
121127
query: str
122128
"""The original search query"""

src/backend/fastapi_app/rag_advanced.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,22 @@
33

44
from openai import AsyncAzureOpenAI, AsyncOpenAI
55
from openai.types.chat import ChatCompletionMessageParam
6-
from pydantic_ai import Agent, RunContext
6+
from pydantic_ai import Agent
77
from pydantic_ai.messages import ModelMessagesTypeAdapter
88
from pydantic_ai.models.openai import OpenAIModel
99
from pydantic_ai.providers.openai import OpenAIProvider
1010
from pydantic_ai.settings import ModelSettings
1111

1212
from fastapi_app.api_models import (
1313
AIChatRoles,
14-
BrandFilter,
1514
ChatRequestOverrides,
1615
Filter,
1716
ItemPublic,
1817
Message,
19-
PriceFilter,
2018
RAGContext,
2119
RetrievalResponse,
2220
RetrievalResponseDelta,
21+
SearchArguments,
2322
SearchResults,
2423
ThoughtStep,
2524
)
@@ -59,7 +58,7 @@ def __init__(
5958
),
6059
system_prompt=self.query_prompt_template,
6160
tools=[self.search_database],
62-
output_type=SearchResults,
61+
output_type=SearchArguments,
6362
)
6463
self.answer_agent = Agent(
6564
pydantic_chat_model,
@@ -73,10 +72,7 @@ def __init__(
7372

7473
async def search_database(
7574
self,
76-
ctx: RunContext[ChatParams],
77-
search_query: str,
78-
price_filter: Optional[PriceFilter] = None,
79-
brand_filter: Optional[BrandFilter] = None,
75+
search_arguments: SearchArguments,
8076
) -> SearchResults:
8177
"""
8278
Search PostgreSQL database for relevant products based on user query
@@ -91,52 +87,55 @@ async def search_database(
9187
"""
9288
# Only send non-None filters
9389
filters: list[Filter] = []
94-
if price_filter:
95-
filters.append(price_filter)
96-
if brand_filter:
97-
filters.append(brand_filter)
90+
if search_arguments.price_filter:
91+
filters.append(search_arguments.price_filter)
92+
if search_arguments.brand_filter:
93+
filters.append(search_arguments.brand_filter)
9894
results = await self.searcher.search_and_embed(
99-
search_query,
100-
top=ctx.deps.top,
101-
enable_vector_search=ctx.deps.enable_vector_search,
102-
enable_text_search=ctx.deps.enable_text_search,
95+
search_arguments.search_query,
96+
top=self.chat_params.top,
97+
enable_vector_search=self.chat_params.enable_vector_search,
98+
enable_text_search=self.chat_params.enable_text_search,
10399
filters=filters,
104100
)
105101
return SearchResults(
106-
query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
102+
query=search_arguments.search_query,
103+
items=[ItemPublic.model_validate(item.to_dict()) for item in results],
104+
filters=filters,
107105
)
108106

109107
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
110108
few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
111109
user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
112-
results = await self.search_agent.run(
110+
search_agent_runner = await self.search_agent.run(
113111
user_query,
114112
message_history=few_shots + self.chat_params.past_messages,
115-
deps=self.chat_params,
113+
output_type=SearchArguments,
116114
)
117-
items = results.output.items
115+
search_arguments = search_agent_runner.output
116+
search_results = await self.search_database(search_arguments=search_arguments)
118117
thoughts = [
119118
ThoughtStep(
120119
title="Prompt to generate search arguments",
121-
description=results.all_messages(),
120+
description=search_agent_runner.all_messages(),
122121
props=self.model_for_thoughts,
123122
),
124123
ThoughtStep(
125124
title="Search using generated search arguments",
126-
description=results.output.query,
125+
description=search_results.query,
127126
props={
128127
"top": self.chat_params.top,
129128
"vector_search": self.chat_params.enable_vector_search,
130129
"text_search": self.chat_params.enable_text_search,
131-
"filters": results.output.filters,
130+
"filters": search_results.filters,
132131
},
133132
),
134133
ThoughtStep(
135134
title="Search results",
136-
description=items,
135+
description=search_results.items,
137136
),
138137
]
139-
return items, thoughts
138+
return search_results.items, thoughts
140139

141140
async def answer(
142141
self,

0 commit comments

Comments
 (0)