Skip to content

Commit 8b78181

Browse files
committed
Add the code
1 parent bfb123e commit 8b78181

14 files changed

+698
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@
1919

2020
# Go workspace file
2121
go.work
22+
23+
.idea/
24+
*.iml

chat.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package openrouter
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net/http"
7+
)
8+
9+
// Chat message role defined by the Sensa API.
10+
11+
type ModelName string
12+
13+
const (
14+
ChatMessageRoleUser = "user"
15+
ChatMessageRoleSystem = "system"
16+
ChatMessageRoleAssistant = "assistant"
17+
)
18+
19+
var (
20+
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
21+
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method") //nolint:lll
22+
)
23+
24+
// CreateChatCompletion — API call to Create a completion for the chat message.
25+
func (c *Client) CreateChatCompletion(
26+
ctx context.Context,
27+
request *ChatCompletionRequest,
28+
) (response *ChatCompletionResponse, err error) {
29+
if request.Stream {
30+
err = ErrChatCompletionStreamNotSupported
31+
return
32+
}
33+
34+
urlSuffix := "/chat/completions"
35+
request.Model = wrapperModels[request.Model]
36+
if !checkSupportsModel(request.Model) {
37+
err = ErrCompletionUnsupportedModel
38+
return
39+
}
40+
41+
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
42+
if err != nil {
43+
return
44+
}
45+
46+
err = c.sendRequest(req, &response)
47+
return
48+
}

chat_stream.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package openrouter
2+
3+
import (
4+
"bufio"
5+
"context"
6+
utils "github.com/casibase/go-openrouter/internal"
7+
)
8+
9+
type ChatCompletionStream struct {
10+
streamReader
11+
}
12+
13+
// CreateChatCompletionStream — API call to create a chat completion w/ streaming
14+
// support. It sets whether to stream back partial progress. If set, tokens will be
15+
// sent as data-only server-sent events as they become available, with the
16+
// stream terminated by a data: [DONE] message.
17+
func (c *Client) CreateChatCompletionStream(
18+
ctx context.Context,
19+
request *ChatCompletionRequest,
20+
) (stream *ChatCompletionStream, err error) {
21+
urlSuffix := "/chat/completions"
22+
request.Model = wrapperModels[request.Model]
23+
if !checkSupportsModel(request.Model) {
24+
err = ErrCompletionUnsupportedModel
25+
return
26+
}
27+
request.Stream = true
28+
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
29+
if err != nil {
30+
return
31+
}
32+
33+
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
34+
if err != nil {
35+
return
36+
}
37+
if isFailureStatusCode(resp) {
38+
return nil, c.handleErrorResp(resp)
39+
}
40+
41+
stream = &ChatCompletionStream{
42+
streamReader: streamReader{
43+
emptyMessagesLimit: c.config.EmptyMessagesLimit,
44+
reader: bufio.NewReader(resp.Body),
45+
response: resp,
46+
errAccumulator: utils.NewErrorAccumulator(),
47+
unmarshaler: &utils.JSONUnmarshaler{},
48+
},
49+
}
50+
return
51+
}

chat_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package openrouter
2+
3+
import (
4+
"context"
5+
"testing"
6+
)
7+
8+
func TestClient_CreateChatCompletion(t *testing.T) {
9+
client, _ := NewClient("", "", "")
10+
11+
req := &ChatCompletionRequest{
12+
Model: "claude-2",
13+
Messages: []ChatCompletionMessage{
14+
{
15+
Role: ChatMessageRoleSystem,
16+
Content: "You are a helpful assistant.",
17+
},
18+
{
19+
Role: ChatMessageRoleUser,
20+
Content: "what is today",
21+
},
22+
},
23+
Stream: false,
24+
Temperature: nil,
25+
TopP: nil,
26+
}
27+
28+
t.Log(client.CreateChatCompletion(context.Background(), req))
29+
//
30+
//r, err := client.CreateChatCompletionStream(context.Background(), req)
31+
//if err != nil {
32+
// t.Error(err)
33+
//}
34+
//t.Log(r)
35+
//for {
36+
// fmt.Println(1)
37+
// r, err := r.Recv()
38+
// if err != nil {
39+
// fmt.Println(err.Error())
40+
// if errors.Is(err, io.EOF) {
41+
// fmt.Println(1)
42+
// break
43+
// }
44+
// t.Error(err)
45+
// }
46+
// t.Logf("%#v", r.Choices)
47+
//}
48+
}

client.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package openrouter
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
utils "github.com/casibase/go-openrouter/internal"
8+
"io"
9+
"net/http"
10+
)
11+
12+
type Client struct {
13+
config ClientConfig
14+
15+
requestBuilder utils.RequestBuilder
16+
}
17+
18+
func NewClient(auth, xTitle, httpReferer string) (*Client, error) {
19+
config, err := DefaultConfig(auth, xTitle, httpReferer)
20+
if err != nil {
21+
return nil, err
22+
}
23+
return NewClientWithConfig(config), nil
24+
}
25+
26+
func NewClientWithConfig(config ClientConfig) *Client {
27+
return &Client{
28+
config: config,
29+
requestBuilder: utils.NewRequestBuilder(),
30+
}
31+
}
32+
33+
func (c *Client) sendRequest(req *http.Request, v any) error {
34+
req.Header.Set("Accept", "application/json; charset=utf-8")
35+
36+
// Check whether Content-Type is already set, Upload Files API requires
37+
// Content-Type == multipart/form-data
38+
contentType := req.Header.Get("Content-Type")
39+
if contentType == "" {
40+
req.Header.Set("Content-Type", "application/json; charset=utf-8")
41+
}
42+
43+
c.setCommonHeaders(req)
44+
45+
res, err := c.config.HTTPClient.Do(req)
46+
if err != nil {
47+
return err
48+
}
49+
defer res.Body.Close()
50+
51+
if isFailureStatusCode(res) {
52+
return c.handleErrorResp(res)
53+
}
54+
55+
return decodeResponse(res.Body, v)
56+
}
57+
58+
func (c *Client) setCommonHeaders(req *http.Request) {
59+
req.Header.Set("HTTP-Referer", c.config.HttpReferer)
60+
req.Header.Set("X-Title", c.config.XTitle)
61+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
62+
}
63+
64+
func isFailureStatusCode(resp *http.Response) bool {
65+
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
66+
}
67+
68+
func decodeResponse(body io.Reader, v any) error {
69+
if v == nil {
70+
return nil
71+
}
72+
73+
if result, ok := v.(*string); ok {
74+
return decodeString(body, result)
75+
}
76+
return json.NewDecoder(body).Decode(v)
77+
}
78+
79+
func decodeString(body io.Reader, output *string) error {
80+
b, err := io.ReadAll(body)
81+
if err != nil {
82+
return err
83+
}
84+
*output = string(b)
85+
return nil
86+
}
87+
88+
// fullURL returns full URL for request.
89+
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
90+
func (c *Client) fullURL(suffix string) string {
91+
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
92+
}
93+
94+
func (c *Client) newStreamRequest(
95+
ctx context.Context,
96+
method string,
97+
urlSuffix string,
98+
body any) (*http.Request, error) {
99+
req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix), body)
100+
if err != nil {
101+
return nil, err
102+
}
103+
104+
req.Header.Set("Content-Type", "application/json")
105+
req.Header.Set("Accept", "text/event-stream")
106+
req.Header.Set("Cache-Control", "no-cache")
107+
req.Header.Set("Connection", "keep-alive")
108+
109+
c.setCommonHeaders(req)
110+
return req, nil
111+
}
112+
113+
func (c *Client) handleErrorResp(resp *http.Response) error {
114+
var errRes ErrorResponse
115+
116+
err := json.NewDecoder(resp.Body).Decode(&errRes)
117+
if err != nil || errRes.Error == nil {
118+
reqErr := &RequestError{
119+
HTTPStatusCode: resp.StatusCode,
120+
Err: err,
121+
}
122+
if errRes.Error != nil {
123+
reqErr.Err = errRes.Error
124+
}
125+
return reqErr
126+
}
127+
128+
errRes.Error.HTTPStatusCode = resp.StatusCode
129+
return errRes.Error
130+
}

config.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package openrouter
2+
3+
import (
4+
"net/http"
5+
)
6+
7+
const (
8+
routerAPIURLv1 = "https://openrouter.ai/api/v1"
9+
defaultEmptyMessagesLimit uint = 300
10+
)
11+
12+
// ClientConfig is a configuration of a client.
13+
// XTitle、HttpRefer your own site url
14+
type ClientConfig struct {
15+
authToken string
16+
XTitle string
17+
HttpReferer string
18+
BaseURL string
19+
HTTPClient *http.Client
20+
EmptyMessagesLimit uint
21+
}
22+
23+
func DefaultConfig(auth, xTitle, httpReferer string) (ClientConfig, error) {
24+
return ClientConfig{
25+
authToken: auth,
26+
HTTPClient: &http.Client{},
27+
XTitle: xTitle,
28+
HttpReferer: httpReferer,
29+
BaseURL: routerAPIURLv1,
30+
EmptyMessagesLimit: defaultEmptyMessagesLimit,
31+
}, nil
32+
}
33+
34+
func (c ClientConfig) WithHttpClientConfig(client *http.Client) ClientConfig {
35+
c.HTTPClient = client
36+
return c
37+
}

error.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package openrouter
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
7+
8+
// APIError provides error information returned by the OpenAI API.
9+
type APIError struct {
10+
Code any `json:"code,omitempty"`
11+
Message string `json:"message"`
12+
Details any `json:"details"`
13+
HTTPStatusCode int `json:"-"`
14+
}
15+
16+
// RequestError provides informations about generic request errors.
17+
type RequestError struct {
18+
HTTPStatusCode int
19+
Err error
20+
}
21+
22+
type ErrorResponse struct {
23+
Error *APIError `json:"error,omitempty"`
24+
}
25+
26+
func (e *APIError) Error() string {
27+
if e.HTTPStatusCode > 0 {
28+
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message)
29+
}
30+
31+
return e.Message
32+
}
33+
34+
func (e *APIError) UnmarshalJSON(data []byte) (err error) {
35+
var rawMap map[string]json.RawMessage
36+
err = json.Unmarshal(data, &rawMap)
37+
if err != nil {
38+
return
39+
}
40+
41+
err = json.Unmarshal(rawMap["message"], &e.Message)
42+
if err != nil {
43+
return
44+
}
45+
46+
if _, ok := rawMap["code"]; !ok {
47+
return nil
48+
}
49+
50+
// if the api returned a number, we need to force an integer
51+
// since the json package defaults to float64
52+
var intCode int
53+
err = json.Unmarshal(rawMap["code"], &intCode)
54+
if err == nil {
55+
e.Code = intCode
56+
return nil
57+
}
58+
59+
return json.Unmarshal(rawMap["code"], &e.Code)
60+
}
61+
62+
func (e *RequestError) Error() string {
63+
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err)
64+
}
65+
66+
func (e *RequestError) Unwrap() error {
67+
return e.Err
68+
}

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module github.com/casibase/go-openrouter
2+
3+
go 1.19

0 commit comments

Comments
 (0)