Skip to content

Commit 635caf2

Browse files
committed
🔨parser fix typing
1 parent 9cbfd25 commit 635caf2

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

codeinterpreterapi/chains/extract_code.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,20 @@
22
from langchain.chat_models.anthropic import ChatAnthropic
33

44

5-
# TODO: make async
65
def extract_python_code(
76
text: str,
87
llm: BaseLanguageModel,
98
retry: int = 2,
109
):
11-
pass
10+
pass # TODO
11+
12+
13+
async def aextract_python_code(
14+
text: str,
15+
llm: BaseLanguageModel,
16+
retry: int = 2,
17+
):
18+
pass # TODO
1219

1320

1421
async def test():

codeinterpreterapi/parser.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def _type(self) -> str:
3838

3939

4040
class CodeChatAgentOutputParser(AgentOutputParser):
41+
def __init__(self, llm: BaseChatModel, **kwargs):
42+
super().__init__(**kwargs)
43+
self.llm = llm
44+
4145
def get_format_instructions(self) -> str:
4246
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
4347

@@ -46,9 +50,7 @@ def get_format_instructions(self) -> str:
4650
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
4751
raise NotImplementedError
4852

49-
async def aparse(
50-
self, text: str, llm: BaseChatModel
51-
) -> Union[AgentAction, AgentFinish]:
53+
async def aparse(self, text: str) -> Union[AgentAction, AgentFinish]:
5254
try:
5355
response = parse_json_markdown(text)
5456
action, action_input = response["action"], response["action_input"]
@@ -58,8 +60,9 @@ async def aparse(
5860
return AgentAction(action, action_input, text)
5961
except Exception:
6062
if '"action": "python"' in text:
63+
print("TODO: Not implemented")
6164
# extract python code from text with prompt
62-
text = extract_python_code(text, llm=llm) or ""
65+
text = extract_python_code(text, llm=self.llm) or ""
6366
match = re.search(r"```python\n(.*?)```", text)
6467
if match:
6568
code = match.group(1).replace("\\n", "; ")

codeinterpreterapi/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _choose_agent(self) -> BaseSingleActionAgent:
163163
llm=self.llm,
164164
tools=self.tools,
165165
system_message=code_interpreter_system_message.content,
166-
output_parser=CodeChatAgentOutputParser(),
166+
output_parser=CodeChatAgentOutputParser(self.llm),
167167
)
168168
if isinstance(self.llm, BaseChatModel)
169169
else ConversationalAgent.from_llm_and_tools(

0 commit comments

Comments
 (0)