Skip to content

Commit f8eb390

Browse files
authored
Merge pull request #46 from thedadams/list-models-from-providers
feat: add ability to list models from other providers
2 parents 1ff82ee + 583aad0 commit f8eb390

File tree

3 files changed

+103
-18
lines changed

3 files changed

+103
-18
lines changed

gptscript/gptscript.py

+30-10
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,22 @@ class GPTScript:
2626
def __init__(self, opts: GlobalOptions = None):
2727
if opts is None:
2828
opts = GlobalOptions()
29+
self.opts = opts
30+
2931
GPTScript.__gptscript_count += 1
3032

3133
if GPTScript.__server_url == "":
3234
GPTScript.__server_url = os.environ.get("GPTSCRIPT_URL", "127.0.0.1:0")
3335

3436
if GPTScript.__gptscript_count == 1 and os.environ.get("GPTSCRIPT_DISABLE_SERVER", "") != "true":
35-
opts.toEnv()
37+
self.opts.toEnv()
3638

3739
GPTScript.__process = Popen(
3840
[_get_command(), "--listen-address", GPTScript.__server_url, "sdkserver"],
3941
stdin=PIPE,
4042
stdout=PIPE,
4143
stderr=PIPE,
42-
env=opts.Env,
44+
env=self.opts.Env,
4345
text=True,
4446
encoding="utf-8",
4547
)
@@ -81,18 +83,28 @@ def evaluate(
8183
opts: Options = None,
8284
event_handlers: list[Callable[[Run, CallFrame | RunFrame | PromptFrame], Awaitable[None]]] = None
8385
) -> Run:
84-
return Run("evaluate", tool, opts, self._server_url, event_handlers=event_handlers).next_chat(
85-
"" if opts is None else opts.input
86-
)
86+
opts = opts if opts is not None else Options()
87+
return Run(
88+
"evaluate",
89+
tool,
90+
opts.merge_global_opts(self.opts),
91+
self._server_url,
92+
event_handlers=event_handlers,
93+
).next_chat("" if opts is None else opts.input)
8794

8895
def run(
8996
self, tool_path: str,
9097
opts: Options = None,
9198
event_handlers: list[Callable[[Run, CallFrame | RunFrame | PromptFrame], Awaitable[None]]] = None
9299
) -> Run:
93-
return Run("run", tool_path, opts, self._server_url, event_handlers=event_handlers).next_chat(
94-
"" if opts is None else opts.input
95-
)
100+
opts = opts if opts is not None else Options()
101+
return Run(
102+
"run",
103+
tool_path,
104+
opts.merge_global_opts(self.opts),
105+
self._server_url,
106+
event_handlers=event_handlers,
107+
).next_chat("" if opts is None else opts.input)
96108

97109
async def parse(self, file_path: str, disable_cache: bool = False) -> list[Text | Tool]:
98110
out = await self._run_basic_command("parse", {"file": file_path, "disableCache": disable_cache})
@@ -139,8 +151,16 @@ async def version(self) -> str:
139151
async def list_tools(self) -> str:
140152
return await self._run_basic_command("list-tools")
141153

142-
async def list_models(self) -> list[str]:
143-
return (await self._run_basic_command("list-models")).split("\n")
154+
async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[str]:
155+
if self.opts.DefaultModelProvider != "":
156+
if providers is None:
157+
providers = []
158+
providers.append(self.opts.DefaultModelProvider)
159+
160+
return (await self._run_basic_command(
161+
"list-models",
162+
{"providers": providers, "credentialOverrides": credential_overrides}
163+
)).split("\n")
144164

145165

146166
def _get_command():

gptscript/opts.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,33 @@
11
import os
2-
from typing import Mapping
2+
from typing import Mapping, Self
33

44

55
class GlobalOptions:
6-
def __init__(self,
7-
apiKey: str = "", baseURL: str = "", defaultModelProvider: str = "", defaultModel: str = "",
8-
env: Mapping[str, str] = None):
6+
def __init__(
7+
self,
8+
apiKey: str = "",
9+
baseURL: str = "",
10+
defaultModelProvider: str = "",
11+
defaultModel: str = "",
12+
env: Mapping[str, str] = None,
13+
):
914
self.APIKey = apiKey
1015
self.BaseURL = baseURL
1116
self.DefaultModel = defaultModel
1217
self.DefaultModelProvider = defaultModelProvider
1318
self.Env = env
1419

20+
def merge(self, other: Self) -> Self:
21+
cp = self.__class__()
22+
if other is None:
23+
return cp
24+
cp.APIKey = other.APIKey if other.APIKey != "" else self.APIKey
25+
cp.BaseURL = other.BaseURL if other.BaseURL != "" else self.BaseURL
26+
cp.DefaultModel = other.DefaultModel if other.DefaultModel != "" else self.DefaultModel
27+
cp.DefaultModelProvider = other.DefaultModelProvider if other.DefaultModelProvider != "" else self.DefaultModelProvider
28+
cp.Env = (other.Env or []).extend(self.Env or [])
29+
return cp
30+
1531
def toEnv(self):
1632
if self.Env is None:
1733
self.Env = os.environ.copy()
@@ -56,3 +72,20 @@ def __init__(self,
5672
self.location = location
5773
self.env = env
5874
self.forceSequential = forceSequential
75+
76+
def merge_global_opts(self, other: GlobalOptions) -> Self:
77+
cp = super().merge(other)
78+
if other is None:
79+
return cp
80+
cp.input = self.input
81+
cp.disableCache = self.disableCache
82+
cp.subTool = self.subTool
83+
cp.workspace = self.workspace
84+
cp.chatState = self.chatState
85+
cp.confirm = self.confirm
86+
cp.prompt = self.prompt
87+
cp.credentialOverrides = self.credentialOverrides
88+
cp.location = self.location
89+
cp.env = self.env
90+
cp.forceSequential = self.forceSequential
91+
return cp

tests/test_gptscript.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@ def gptscript():
2525
if os.getenv("OPENAI_API_KEY") is None:
2626
pytest.fail("OPENAI_API_KEY not set", pytrace=False)
2727
try:
28+
# Start an initial GPTScript instance.
29+
# This one doesn't have any options, but it's there to ensure that using another instance works as expected in all cases.
30+
g_first = GPTScript()
2831
gptscript = GPTScript(GlobalOptions(apiKey=os.getenv("OPENAI_API_KEY")))
2932
yield gptscript
3033
gptscript.close()
34+
g_first.close()
3135
except Exception as e:
3236
pytest.fail(e, pytrace=False)
3337

@@ -111,6 +115,33 @@ async def test_list_models(gptscript):
111115
assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
112116

113117

118+
@pytest.mark.asyncio
119+
async def test_list_models_from_provider(gptscript):
120+
models = await gptscript.list_models(
121+
providers=["github.com/gptscript-ai/claude3-anthropic-provider"],
122+
credential_overrides=["github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"],
123+
)
124+
assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
125+
for model in models:
126+
assert model.startswith("claude-3-"), "Unexpected model name"
127+
assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
128+
129+
130+
@pytest.mark.asyncio
131+
async def test_list_models_from_default_provider():
132+
g = GPTScript(GlobalOptions(defaultModelProvider="github.com/gptscript-ai/claude3-anthropic-provider"))
133+
try:
134+
models = await g.list_models(
135+
credential_overrides=["github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"],
136+
)
137+
assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
138+
for model in models:
139+
assert model.startswith("claude-3-"), "Unexpected model name"
140+
assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
141+
finally:
142+
g.close()
143+
144+
114145
@pytest.mark.asyncio
115146
async def test_list_tools(gptscript):
116147
out = await gptscript.list_tools()
@@ -472,10 +503,11 @@ async def process_event(r: Run, frame: CallFrame | RunFrame | PromptFrame):
472503
event_content += output.content
473504

474505
tool = ToolDef(tools=["sys.exec"], instructions="List the files in the current directory as '.'.")
475-
out = await gptscript.evaluate(tool,
476-
Options(confirm=True, disableCache=True),
477-
event_handlers=[process_event],
478-
).text()
506+
out = await gptscript.evaluate(
507+
tool,
508+
Options(confirm=True, disableCache=True),
509+
event_handlers=[process_event],
510+
).text()
479511

480512
assert confirm_event_found, "No confirm event"
481513
# Running the `dir` command in Windows will give the contents of the tests directory

0 commit comments

Comments
 (0)