Skip to content

Commit d1c6507

Browse files
authored
Merge pull request #338 from grafana/backend-model-mapping
2 parents eb3e954 + fb6acd0 commit d1c6507

File tree

7 files changed

+213
-29
lines changed

7 files changed

+213
-29
lines changed

packages/grafana-llm-app/pkg/plugin/app.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,20 @@ func NewApp(ctx context.Context, appSettings backend.AppInstanceSettings) (insta
5353
return nil, err
5454
}
5555

56+
if app.settings.Models == nil {
57+
// backwards-compat: if Model settings is nil, use the default one
58+
app.settings.Models = DEFAULT_MODEL_SETTINGS
59+
}
60+
5661
switch app.settings.OpenAI.Provider {
5762
case openAIProviderOpenAI:
58-
p, err := NewOpenAIProvider(app.settings.OpenAI)
63+
p, err := NewOpenAIProvider(app.settings.OpenAI, app.settings.Models)
5964
if err != nil {
6065
return nil, err
6166
}
6267
app.llmProvider = p
6368
case openAIProviderAzure:
64-
p, err := NewAzureProvider(app.settings.OpenAI)
69+
p, err := NewAzureProvider(app.settings.OpenAI, app.settings.Models.Default)
6570
if err != nil {
6671
return nil, err
6772
}

packages/grafana-llm-app/pkg/plugin/azure_provider.go

+22-13
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,18 @@ import (
1111
)
1212

1313
type azure struct {
14-
settings OpenAISettings
15-
oc *openai.Client
14+
settings OpenAISettings
15+
defaultModel Model
16+
oc *openai.Client
1617
}
1718

18-
func NewAzureProvider(settings OpenAISettings) (LLMProvider, error) {
19+
func NewAzureProvider(settings OpenAISettings, defaultModel Model) (LLMProvider, error) {
1920
client := &http.Client{
2021
Timeout: 2 * time.Minute,
2122
}
2223
p := &azure{
23-
settings: settings,
24+
settings: settings,
25+
defaultModel: defaultModel,
2426
}
2527

2628
// go-openai expects the URL without the '/openai' suffix, which is
@@ -48,14 +50,25 @@ func (p *azure) Models(ctx context.Context) (ModelResponse, error) {
4850
return ModelResponse{Data: models}, nil
4951
}
5052

51-
func (p *azure) ChatCompletion(ctx context.Context, req ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
53+
func (p *azure) getDeployment(model Model) (string, error) {
5254
mapping, err := p.getAzureMapping()
5355
if err != nil {
54-
return openai.ChatCompletionResponse{}, err
56+
return "", err
57+
}
58+
if model == "" {
59+
model = p.defaultModel
5560
}
56-
deployment := mapping[req.Model]
61+
deployment := mapping[model]
5762
if deployment == "" {
58-
return openai.ChatCompletionResponse{}, fmt.Errorf("%w: no deployment found for model: %s", errBadRequest, req.Model)
63+
return "", fmt.Errorf("%w: no deployment found for model: %s", errBadRequest, model)
64+
}
65+
return deployment, nil
66+
}
67+
68+
func (p *azure) ChatCompletion(ctx context.Context, req ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
69+
deployment, err := p.getDeployment(req.Model)
70+
if err != nil {
71+
return openai.ChatCompletionResponse{}, err
5972
}
6073

6174
r := req.ChatCompletionRequest
@@ -69,14 +82,10 @@ func (p *azure) ChatCompletion(ctx context.Context, req ChatCompletionRequest) (
6982
}
7083

7184
func (p *azure) ChatCompletionStream(ctx context.Context, req ChatCompletionRequest) (<-chan ChatCompletionStreamResponse, error) {
72-
mapping, err := p.getAzureMapping()
85+
deployment, err := p.getDeployment(req.Model)
7386
if err != nil {
7487
return nil, err
7588
}
76-
deployment := mapping[req.Model]
77-
if deployment == "" {
78-
return nil, fmt.Errorf("%w: no deployment found for model: %s", errBadRequest, req.Model)
79-
}
8089

8190
r := req.ChatCompletionRequest
8291
// For the Azure mapping we want to use the name of the mapped deployment as the model.

packages/grafana-llm-app/pkg/plugin/grafana_provider.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (p *grafanaProvider) Models(ctx context.Context) (ModelResponse, error) {
5959

6060
func (p *grafanaProvider) ChatCompletion(ctx context.Context, req ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
6161
r := req.ChatCompletionRequest
62-
r.Model = req.Model.toOpenAI()
62+
r.Model = req.Model.toOpenAI(DEFAULT_MODEL_SETTINGS)
6363
resp, err := p.oc.CreateChatCompletion(ctx, r)
6464
if err != nil {
6565
log.DefaultLogger.Error("error creating grafana chat completion", "err", err)
@@ -70,6 +70,6 @@ func (p *grafanaProvider) ChatCompletion(ctx context.Context, req ChatCompletion
7070

7171
// ChatCompletionStream proxies a streaming chat completion through the
// Grafana-managed gateway, resolving the abstract model name with the
// built-in default model settings.
func (p *grafanaProvider) ChatCompletionStream(ctx context.Context, req ChatCompletionRequest) (<-chan ChatCompletionStreamResponse, error) {
	chatReq := req.ChatCompletionRequest
	chatReq.Model = req.Model.toOpenAI(DEFAULT_MODEL_SETTINGS)
	return streamOpenAIRequest(ctx, chatReq, p.oc)
}

packages/grafana-llm-app/pkg/plugin/llm_provider.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,17 @@ func (m *Model) UnmarshalJSON(data []byte) error {
4949
return fmt.Errorf("unrecognized model: %s", dataString)
5050
}
5151

52-
func (m Model) toOpenAI() string {
53-
// TODO: Add ability to change which model is used for each abstraction in settings.
54-
switch m {
55-
case ModelBase:
56-
return "gpt-3.5-turbo"
57-
case ModelLarge:
58-
return "gpt-4-turbo"
52+
func (m Model) toOpenAI(modelSettings *ModelSettings) string {
53+
if modelSettings == nil || len(modelSettings.Mapping) == 0 {
54+
switch m {
55+
case ModelBase:
56+
return "gpt-3.5-turbo"
57+
case ModelLarge:
58+
return "gpt-4-turbo"
59+
}
60+
panic(fmt.Sprintf("unrecognized model: %s", m))
5961
}
60-
panic("unknown model: " + m)
62+
return modelSettings.getModel(m)
6163
}
6264

6365
type ChatCompletionRequest struct {

packages/grafana-llm-app/pkg/plugin/openai_provider.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ import (
1515

1616
type openAI struct {
1717
settings OpenAISettings
18+
models *ModelSettings
1819
oc *openai.Client
1920
}
2021

21-
func NewOpenAIProvider(settings OpenAISettings) (LLMProvider, error) {
22+
func NewOpenAIProvider(settings OpenAISettings, models *ModelSettings) (LLMProvider, error) {
2223
client := &http.Client{
2324
Timeout: 2 * time.Minute,
2425
}
@@ -32,6 +33,7 @@ func NewOpenAIProvider(settings OpenAISettings) (LLMProvider, error) {
3233
cfg.OrgID = settings.OrganizationID
3334
return &openAI{
3435
settings: settings,
36+
models: models,
3537
oc: openai.NewClientWithConfig(cfg),
3638
}, nil
3739
}
@@ -53,7 +55,7 @@ type openAIChatCompletionRequest struct {
5355

5456
func (p *openAI) ChatCompletion(ctx context.Context, req ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
5557
r := req.ChatCompletionRequest
56-
r.Model = req.Model.toOpenAI()
58+
r.Model = req.Model.toOpenAI(p.models)
5759
resp, err := p.oc.CreateChatCompletion(ctx, r)
5860
if err != nil {
5961
log.DefaultLogger.Error("error creating openai chat completion", "err", err)
@@ -64,7 +66,7 @@ func (p *openAI) ChatCompletion(ctx context.Context, req ChatCompletionRequest)
6466

6567
func (p *openAI) ChatCompletionStream(ctx context.Context, req ChatCompletionRequest) (<-chan ChatCompletionStreamResponse, error) {
6668
r := req.ChatCompletionRequest
67-
r.Model = req.Model.toOpenAI()
69+
r.Model = req.Model.toOpenAI(p.models)
6870
return streamOpenAIRequest(ctx, r, p.oc)
6971
}
7072

packages/grafana-llm-app/pkg/plugin/resources_test.go

+134-1
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,30 @@ func TestCallOpenAIProxy(t *testing.T) {
237237

238238
expStatus: http.StatusOK,
239239
},
240+
{
241+
name: "openai - empty model",
242+
243+
settings: Settings{
244+
OpenAI: OpenAISettings{
245+
OrganizationID: "myOrg",
246+
Provider: openAIProviderOpenAI,
247+
},
248+
},
249+
apiKey: "abcd1234",
250+
251+
method: http.MethodPost,
252+
path: "/openai/v1/chat/completions",
253+
body: []byte(`{"messages": [{"content":"some stuff"}]}`),
254+
255+
expReqHeaders: http.Header{
256+
"Authorization": {"Bearer abcd1234"},
257+
"OpenAI-Organization": {"myOrg"},
258+
},
259+
expReqPath: "/v1/chat/completions",
260+
expReqBody: []byte(`{"model": "gpt-3.5-turbo", "messages": [{"content":"some stuff"}]}`),
261+
262+
expStatus: http.StatusOK,
263+
},
240264
{
241265
name: "openai - streaming",
242266

@@ -265,6 +289,34 @@ func TestCallOpenAIProxy(t *testing.T) {
265289
// newlines (required by the SSE spec) are escaped.
266290
expBody: []byte("data: {\"id\":\"\",\"object\":\"\",\"created\":0,\"model\":\"\",\"choices\":null,\"system_fingerprint\":\"\"}\n\ndata: [DONE]\n\n"),
267291
},
292+
{
293+
name: "openai - streaming - empty model",
294+
295+
settings: Settings{
296+
OpenAI: OpenAISettings{
297+
OrganizationID: "myOrg",
298+
Provider: openAIProviderOpenAI,
299+
},
300+
},
301+
apiKey: "abcd1234",
302+
303+
method: http.MethodPost,
304+
path: "/openai/v1/chat/completions",
305+
body: []byte(`{"stream": true, "messages": [{"content":"some stuff"}]}`),
306+
307+
expReqHeaders: http.Header{
308+
"Authorization": {"Bearer abcd1234"},
309+
"OpenAI-Organization": {"myOrg"},
310+
},
311+
expReqPath: "/v1/chat/completions",
312+
expReqBody: []byte(`{"model": "gpt-3.5-turbo", "stream": true, "messages": [{"content":"some stuff"}]}`),
313+
314+
expStatus: http.StatusOK,
315+
316+
// We need to use regular strings rather than raw strings here otherwise the double
317+
// newlines (required by the SSE spec) are escaped.
318+
expBody: []byte("data: {\"id\":\"\",\"object\":\"\",\"created\":0,\"model\":\"\",\"choices\":null,\"system_fingerprint\":\"\"}\n\ndata: [DONE]\n\n"),
319+
},
268320
{
269321
name: "azure",
270322

@@ -293,6 +345,62 @@ func TestCallOpenAIProxy(t *testing.T) {
293345

294346
expStatus: http.StatusOK,
295347
},
348+
{
349+
name: "azure - abstract model",
350+
351+
settings: Settings{
352+
OpenAI: OpenAISettings{
353+
OrganizationID: "myOrg",
354+
Provider: openAIProviderAzure,
355+
AzureMapping: [][]string{
356+
{"gpt-3.5-turbo", "gpt-35-turbo"},
357+
},
358+
},
359+
},
360+
361+
apiKey: "abcd1234",
362+
363+
method: http.MethodPost,
364+
path: "/openai/v1/chat/completions",
365+
body: []byte(`{"model": "base", "messages": [{"content":"some stuff"}]}`),
366+
367+
expReqHeaders: http.Header{
368+
"api-key": {"abcd1234"},
369+
},
370+
expReqPath: "/openai/deployments/gpt-35-turbo/chat/completions",
371+
// the 'model' field should have been removed.
372+
expReqBody: []byte(`{"messages":[{"content":"some stuff"}]}`),
373+
374+
expStatus: http.StatusOK,
375+
},
376+
{
377+
name: "azure - empty model",
378+
379+
settings: Settings{
380+
OpenAI: OpenAISettings{
381+
OrganizationID: "myOrg",
382+
Provider: openAIProviderAzure,
383+
AzureMapping: [][]string{
384+
{"gpt-3.5-turbo", "gpt-35-turbo"},
385+
},
386+
},
387+
},
388+
389+
apiKey: "abcd1234",
390+
391+
method: http.MethodPost,
392+
path: "/openai/v1/chat/completions",
393+
body: []byte(`{"messages": [{"content":"some stuff"}]}`),
394+
395+
expReqHeaders: http.Header{
396+
"api-key": {"abcd1234"},
397+
},
398+
expReqPath: "/openai/deployments/gpt-35-turbo/chat/completions",
399+
// the 'model' field should have been removed.
400+
expReqBody: []byte(`{"messages":[{"content":"some stuff"}]}`),
401+
402+
expStatus: http.StatusOK,
403+
},
296404
{
297405
name: "azure invalid deployment",
298406

@@ -310,7 +418,7 @@ func TestCallOpenAIProxy(t *testing.T) {
310418
method: http.MethodPost,
311419
path: "/openai/v1/chat/completions",
312420
// note no gpt-4 in AzureMapping.
313-
body: []byte(`{"model": "gpt-4", "messages": [{"content":"some stuff"}]}`),
421+
body: []byte(`{"model": "gpt-4-turbo", "messages": [{"content":"some stuff"}]}`),
314422

315423
expNilRequest: true,
316424

@@ -364,6 +472,31 @@ func TestCallOpenAIProxy(t *testing.T) {
364472
expReqPath: "/llm/openai/v1/chat/completions",
365473
expReqBody: []byte(`{"model": "gpt-3.5-turbo", "messages": [{"content":"some stuff"}]}`),
366474

475+
expStatus: http.StatusOK,
476+
},
477+
{
478+
name: "grafana-managed llm gateway - empty model",
479+
480+
settings: Settings{
481+
Tenant: "123",
482+
GrafanaComAPIKey: "abcd1234",
483+
OpenAI: OpenAISettings{
484+
Provider: openAIProviderGrafana,
485+
},
486+
},
487+
apiKey: "abcd1234",
488+
489+
method: http.MethodPost,
490+
path: "/openai/v1/chat/completions",
491+
body: []byte(`{"messages": [{"content":"some stuff"}]}`),
492+
493+
expReqHeaders: http.Header{
494+
"Authorization": {"Bearer 123:abcd1234"},
495+
"X-Scope-OrgID": {"123"},
496+
},
497+
expReqPath: "/llm/openai/v1/chat/completions",
498+
expReqBody: []byte(`{"model": "gpt-3.5-turbo", "messages": [{"content":"some stuff"}]}`),
499+
367500
expStatus: http.StatusOK,
368501
},
369502
} {

packages/grafana-llm-app/pkg/plugin/settings.go

+33
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,36 @@ type OpenAISettings struct {
4242
apiKey string
4343
}
4444

45+
// ModelMapping pairs one of our abstract model identifiers with the
// provider-specific model name it should translate to.
// NOTE(review): this type is not referenced elsewhere in this change —
// presumably it mirrors the JSON shape produced by the frontend settings
// UI; confirm before removing.
type ModelMapping struct {
	Model Model  `json:"model"`
	Name  string `json:"name"`
}
49+
50+
type ModelSettings struct {
51+
// Default model to use when no model is defined, or the model is not found.
52+
Default Model `json:"default"`
53+
54+
// Mapping is mapping from our abstract model names to the provider's model names.
55+
Mapping map[Model]string `json:"mapping"`
56+
}
57+
58+
func (c ModelSettings) getModel(model Model) string {
59+
// Helper function to get the name of a model.
60+
if name, ok := c.Mapping[model]; ok {
61+
return name
62+
}
63+
// If the model is not found, return the default model.
64+
return c.getModel(c.Default)
65+
}
66+
67+
// DEFAULT_MODEL_SETTINGS is the fallback model configuration applied when
// the plugin settings contain no "models" section (backwards compatibility
// with settings saved before model mapping existed).
// NOTE(review): Go naming convention would be DefaultModelSettings
// (MixedCaps, no underscores); renaming touches several call sites, so it
// is left unchanged here.
var DEFAULT_MODEL_SETTINGS = &ModelSettings{
	Default: ModelBase,
	Mapping: map[Model]string{
		ModelBase:  "gpt-3.5-turbo",
		ModelLarge: "gpt-4-turbo",
	},
}
74+
4575
// LLMGatewaySettings contains the configuration for the Grafana Managed Key LLM solution.
4676
type LLMGatewaySettings struct {
4777
// This is the URL of the LLM endpoint of the machine learning backend which proxies
@@ -72,6 +102,9 @@ type Settings struct {
72102
// VectorDB settings. May rely on OpenAI settings.
73103
Vector vector.VectorSettings `json:"vector"`
74104

105+
// Models contains the user-specified models.
106+
Models *ModelSettings `json:"models"`
107+
75108
// LLMGateway provides Grafana-managed OpenAI.
76109
LLMGateway LLMGatewaySettings `json:"llmGateway"`
77110
}

0 commit comments

Comments
 (0)