Skip to content

Commit 368e7ea

Browse files
authored
Merge pull request #77 from shroominic/chat_history_backends
Chat history backends + Session ID management
2 parents e10cdf4 + 9cc650c commit 368e7ea

File tree

8 files changed

+320
-112
lines changed

8 files changed

+320
-112
lines changed

.env.example

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1+
# (set True to enable logging)
2+
VERBOSE=False
3+
14
# (required)
25
OPENAI_API_KEY=
6+
# ANTHROPIC_API_KEY=
7+
38
# (optional, required for production)
49
# CODEBOX_API_KEY=
5-
# (set True to enable logging)
6-
VERBOSE=False
10+
711
# (optional, required for Azure OpenAI)
812
# OPENAI_API_TYPE=azure
913
# OPENAI_API_VERSION=2023-07-01-preview
1014
# OPENAI_API_BASE=
1115
# DEPLOYMENT_NAME=
16+
17+
# (optional, chat history backend: codebox, postgres or redis; any other value or unset falls back to the default ChatMessageHistory)
18+
# HISTORY_BACKEND=postgres
19+
# REDIS_URL=
20+
# POSTGRES_URL=

codeinterpreterapi/chat_history.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import asyncio
2+
import json
3+
from typing import List
4+
5+
from codeboxapi import CodeBox # type: ignore
6+
from langchain.schema import BaseChatMessageHistory
7+
from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict
8+
9+
10+
# TODO: This is probably not efficient, but it works for now.
11+
class CodeBoxChatMessageHistory(BaseChatMessageHistory):
    """
    Chat message history that stores history inside the codebox.

    The history is kept as a JSON list (the ``messages_to_dict`` form) in the
    ``history.json`` file of the codebox. The file is read at most once per
    instance; after that an in-memory copy is the source of truth and every
    change re-uploads the whole list. Changes made to ``history.json`` by
    other means after the first read are therefore not picked up.

    When called while an event loop is running, writes are scheduled on that
    loop as tasks and are not awaited by this class.
    """

    def __init__(self, codebox: CodeBox):
        """
        Args:
            codebox: The codebox whose ``history.json`` file backs this history.
        """
        self.codebox = codebox
        # In-memory copy of the history; only meaningful once _loaded is True.
        self._cache: List[BaseMessage] = []
        self._loaded = False
        # Strong references to scheduled tasks: the event loop itself only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._pending: set = set()

        if "history.json" not in [f.name for f in self.codebox.list_files()]:
            # Nothing stored yet: start empty and create the file as an empty
            # JSON list, the same shape add_message() writes.
            self._loaded = True
            self._dispatch(
                self.codebox.upload, self.codebox.aupload, "history.json", b"[]"
            )

    @staticmethod
    def _running_loop():
        """Return the running event loop, or None when no loop is running."""
        try:
            return asyncio.get_running_loop()
        except RuntimeError:
            return None

    def _dispatch(self, sync_call, async_call, *args) -> None:
        """Run ``sync_call`` directly, or schedule ``async_call`` on the running loop."""
        loop = self._running_loop()
        if loop is None:
            sync_call(*args)
            return
        task = loop.create_task(async_call(*args))
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve the messages, loading them from the codebox on first use"""
        if not self._loaded:
            # A property cannot await, and run_until_complete() raises
            # RuntimeError on a loop that is already running, so the one-time
            # load always uses the blocking download.
            file_content = self.codebox.download("history.json").content
            self._cache = (
                messages_from_dict(json.loads(file_content.decode("utf-8")))
                if file_content
                else []
            )
            self._loaded = True
        # Hand out a copy so callers cannot mutate the cached history.
        return list(self._cache)

    def add_message(self, message: BaseMessage) -> None:
        """Append the message to the history and store it in the codebox"""
        updated = self.messages
        updated.append(message)
        # Update the cache before uploading so that a second call made before
        # a scheduled upload has run still sees this message.
        self._cache = updated
        content = json.dumps(messages_to_dict(updated)).encode("utf-8")
        self._dispatch(
            self.codebox.upload, self.codebox.aupload, "history.json", content
        )

    def clear(self) -> None:
        """Clear the history and remove the history file from the codebox"""
        self._cache = []
        self._loaded = True
        code = "import os; os.remove('history.json')"
        self._dispatch(self.codebox.run, self.codebox.arun, code)

codeinterpreterapi/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@ class CodeInterpreterAPISettings(BaseSettings):
1414

1515
VERBOSE: bool = False
1616

17-
CODEBOX_API_KEY: Optional[str] = None
1817
OPENAI_API_KEY: Optional[str] = None
18+
CODEBOX_API_KEY: Optional[str] = None
19+
20+
HISTORY_BACKEND: Optional[str] = None
21+
REDIS_URL: str = "redis://localhost:6379"
22+
POSTGRES_URL: str = "postgresql://postgres:postgres@localhost:5432/postgres"
1923

2024

2125
settings = CodeInterpreterAPISettings()

codeinterpreterapi/schema/file.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ class File(BaseModel):
99

1010
@classmethod
def from_path(cls, path: str):
    """
    Create an instance from a file on the local filesystem.

    Args:
        path: Absolute or relative path of the file to read.

    Returns:
        An instance named after the final path component and holding the
        file's bytes as content.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    from pathlib import Path

    # Use the path exactly as given. Relative paths already resolve against
    # the working directory, and prefixing "./" broke absolute paths that do
    # not start with "/" (for example Windows drive paths).
    source = Path(path)
    return cls(name=source.name, content=source.read_bytes())
@@ -33,6 +35,8 @@ async def afrom_url(cls, url: str):
3335
return cls(name=url.split("/")[-1], content=await r.read())
3436

3537
def save(self, path: str):
    """
    Write this file's content to the local filesystem.

    Args:
        path: Absolute or relative destination path; an existing file is
            overwritten.

    Raises:
        OSError: If the destination cannot be opened for writing.
    """
    # Use the path exactly as given. Relative paths already resolve against
    # the working directory, and prefixing "./" broke absolute paths that do
    # not start with "/" (for example Windows drive paths).
    with open(path, "wb") as f:
        f.write(self.content)
3842

codeinterpreterapi/session.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import base64
22
import re
33
import traceback
4-
import uuid
54
from io import BytesIO
65
from os import getenv
76
from typing import Optional
7+
from uuid import UUID, uuid4
88

99
from codeboxapi import CodeBox # type: ignore
1010
from codeboxapi.schema import CodeBoxOutput # type: ignore
@@ -17,8 +17,13 @@
1717
from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI
1818
from langchain.chat_models.base import BaseChatModel
1919
from langchain.memory import ConversationBufferMemory
20+
from langchain.memory.chat_message_histories import (
21+
ChatMessageHistory,
22+
PostgresChatMessageHistory,
23+
RedisChatMessageHistory,
24+
)
2025
from langchain.prompts.chat import MessagesPlaceholder
21-
from langchain.schema.language_model import BaseLanguageModel
26+
from langchain.schema import BaseChatMessageHistory, BaseLanguageModel
2227
from langchain.tools import BaseTool, StructuredTool
2328

2429
from codeinterpreterapi.agents import OpenAIFunctionsAgent
@@ -28,6 +33,7 @@
2833
get_file_modifications,
2934
remove_download_link,
3035
)
36+
from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory
3137
from codeinterpreterapi.config import settings
3238
from codeinterpreterapi.parser import CodeAgentOutputParser, CodeChatAgentOutputParser
3339
from codeinterpreterapi.prompts import code_interpreter_system_message
@@ -47,20 +53,35 @@ def __init__(
4753
additional_tools: list[BaseTool] = [],
4854
**kwargs,
4955
) -> None:
50-
self.codebox = CodeBox(**kwargs)
56+
self.codebox = CodeBox()
5157
self.verbose = kwargs.get("verbose", settings.VERBOSE)
5258
self.tools: list[BaseTool] = self._tools(additional_tools)
5359
self.llm: BaseLanguageModel = llm or self._choose_llm(**kwargs)
54-
self.agent_executor: AgentExecutor = self._agent_executor()
60+
self.agent_executor: Optional[AgentExecutor] = None
5561
self.input_files: list[File] = []
5662
self.output_files: list[File] = []
5763
self.code_log: list[tuple[str, str]] = []
5864

65+
@classmethod
def from_id(cls, session_id: UUID, **kwargs) -> "CodeInterpreterSession":
    """
    Build a session around the codebox returned by ``CodeBox.from_id``.

    Args:
        session_id: ID handed to ``CodeBox.from_id``.
        **kwargs: Forwarded unchanged to the class constructor.

    Returns:
        A session whose agent executor has already been created.
    """
    restored = cls(**kwargs)
    # Swap in the looked-up codebox before building the executor, so the
    # history backend is created against it rather than the constructor's one.
    restored.codebox = CodeBox.from_id(session_id)
    restored.agent_executor = restored._agent_executor()
    return restored
71+
72+
@property
def session_id(self) -> Optional[UUID]:
    """Session ID reported by the underlying codebox."""
    box = self.codebox
    return box.session_id
75+
5976
def start(self) -> SessionStatus:
    """
    Start the codebox, then create the agent executor.

    Returns:
        The codebox start status converted to a ``SessionStatus``.
    """
    codebox_status = self.codebox.start()
    session_status = SessionStatus.from_codebox_status(codebox_status)
    # The executor is built only after the codebox has been started.
    self.agent_executor = self._agent_executor()
    return session_status
6180

6281
async def astart(self) -> SessionStatus:
    """
    Asynchronously start the codebox, then create the agent executor.

    Returns:
        The codebox start status converted to a ``SessionStatus``.
    """
    codebox_status = await self.codebox.astart()
    session_status = SessionStatus.from_codebox_status(codebox_status)
    # The executor is built only after the codebox has been started.
    self.agent_executor = self._agent_executor()
    return session_status
6485

6586
def _tools(self, additional_tools: list[BaseTool]) -> list[BaseTool]:
6687
return additional_tools + [
@@ -108,15 +129,15 @@ def _choose_llm(
108129
openai_api_key=openai_api_key,
109130
max_retries=3,
110131
request_timeout=60 * 3,
111-
)
132+
) # type: ignore
112133
else:
113134
return ChatOpenAI(
114135
temperature=0.03,
115136
model=model,
116137
openai_api_key=openai_api_key,
117138
max_retries=3,
118139
request_timeout=60 * 3,
119-
)
140+
) # type: ignore
120141
elif "claude" in model:
121142
return ChatAnthropic(model=model)
122143
else:
@@ -148,14 +169,33 @@ def _choose_agent(self) -> BaseSingleActionAgent:
148169
)
149170
)
150171

172+
def _history_backend(self) -> BaseChatMessageHistory:
    """
    Create the chat history store selected by ``settings.HISTORY_BACKEND``.

    Returns:
        A codebox, Redis or Postgres backed history for the values
        ``"codebox"``, ``"redis"`` and ``"postgres"``; a plain
        ``ChatMessageHistory`` for any other value, including ``None``.
    """
    backend = settings.HISTORY_BACKEND

    if backend == "codebox":
        return CodeBoxChatMessageHistory(codebox=self.codebox)

    if backend == "redis":
        return RedisChatMessageHistory(
            session_id=str(self.session_id),
            url=settings.REDIS_URL,
        )

    if backend == "postgres":
        return PostgresChatMessageHistory(
            session_id=str(self.session_id),
            connection_string=settings.POSTGRES_URL,
        )

    # Unset or unrecognised backend name.
    return ChatMessageHistory()
188+
151189
def _agent_executor(self) -> AgentExecutor:
    """
    Assemble the agent executor for this session.

    Returns:
        An ``AgentExecutor`` built from the chosen agent and the session's
        tools, with a conversation buffer memory backed by the configured
        chat history store.
    """
    # The agent is created before the memory, matching the argument order of
    # the single-call form this replaces.
    agent = self._choose_agent()
    tools = self.tools
    verbose = self.verbose
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        chat_memory=self._history_backend(),
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        max_iterations=9,
        tools=tools,
        verbose=verbose,
        memory=memory,
    )
161201

@@ -178,7 +218,7 @@ def _run_handler(self, code: str):
178218
raise TypeError("Expected output.content to be a string.")
179219

180220
if output.type == "image/png":
181-
filename = f"image-{uuid.uuid4()}.png"
221+
filename = f"image-{uuid4()}.png"
182222
file_buffer = BytesIO(base64.b64decode(output.content))
183223
file_buffer.name = filename
184224
self.output_files.append(File(name=filename, content=file_buffer.read()))
@@ -225,7 +265,7 @@ async def _arun_handler(self, code: str):
225265
raise TypeError("Expected output.content to be a string.")
226266

227267
if output.type == "image/png":
228-
filename = f"image-{uuid.uuid4()}.png"
268+
filename = f"image-{uuid4()}.png"
229269
file_buffer = BytesIO(base64.b64decode(output.content))
230270
file_buffer.name = filename
231271
self.output_files.append(File(name=filename, content=file_buffer.read()))
@@ -349,6 +389,7 @@ def generate_response_sync(
349389
user_request = UserRequest(content=user_msg, files=files)
350390
try:
351391
self._input_handler(user_request)
392+
assert self.agent_executor, "Session not initialized."
352393
response = self.agent_executor.run(input=user_request.content)
353394
return self._output_handler(response)
354395
except Exception as e:
@@ -392,6 +433,7 @@ async def agenerate_response(
392433
user_request = UserRequest(content=user_msg, files=files)
393434
try:
394435
await self._ainput_handler(user_request)
436+
assert self.agent_executor, "Session not initialized."
395437
response = await self.agent_executor.arun(input=user_request.content)
396438
return await self._aoutput_handler(response)
397439
except Exception as e:

examples/chat_history_backend.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
3+
# Must be set before codeinterpreterapi is imported below.
os.environ["HISTORY_BACKEND"] = "redis"
# The settings field is named REDIS_URL; a REDIS_HOST variable is never read.
os.environ["REDIS_URL"] = "redis://localhost:6379"
5+
6+
from codeinterpreterapi import CodeInterpreterSession # noqa: E402
7+
8+
9+
def main():
    """Run one request, then reattach to the same session by ID for a follow-up."""
    first_session = CodeInterpreterSession()
    first_session.start()

    session_id = first_session.session_id
    print("Session ID:", session_id)

    first_reply = first_session.generate_response_sync(
        "Plot the bitcoin chart of 2023 YTD"
    )
    first_reply.show()

    # Drop the first session object; only its ID is carried forward.
    del first_session

    assert session_id is not None
    resumed_session = CodeInterpreterSession.from_id(session_id)

    follow_up = resumed_session.generate_response_sync("Now for the last 5 years")
    follow_up.show()

    resumed_session.stop()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)