chore: update for dataset rewrite #83


Merged
8 commits merged on Nov 6, 2024
169 changes: 42 additions & 127 deletions datasets.go
@@ -2,10 +2,8 @@ package gptscript

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
)

type DatasetElementMeta struct {
@@ -15,7 +13,8 @@ type DatasetElementMeta struct {

type DatasetElement struct {
DatasetElementMeta `json:",inline"`
Contents []byte `json:"contents"`
Contents string `json:"contents"`
BinaryContents []byte `json:"binaryContents"`
}

type DatasetMeta struct {
@@ -24,34 +23,17 @@ type DatasetMeta struct {
Description string `json:"description"`
}

type Dataset struct {
DatasetMeta `json:",inline"`
BaseDir string `json:"baseDir,omitempty"`
Elements map[string]DatasetElementMeta `json:"elements"`
}

type datasetRequest struct {
Input string `json:"input"`
WorkspaceID string `json:"workspaceID"`
DatasetToolRepo string `json:"datasetToolRepo"`
Env []string `json:"env"`
}

type createDatasetArgs struct {
Name string `json:"datasetName"`
Description string `json:"datasetDescription"`
}

type addDatasetElementArgs struct {
DatasetID string `json:"datasetID"`
ElementName string `json:"elementName"`
ElementDescription string `json:"elementDescription"`
ElementContent string `json:"elementContent"`
Input string `json:"input"`
DatasetTool string `json:"datasetTool"`
Env []string `json:"env"`
}

type addDatasetElementsArgs struct {
DatasetID string `json:"datasetID"`
Elements []DatasetElement `json:"elements"`
DatasetID string `json:"datasetID"`
Name string `json:"name"`
Description string `json:"description"`
Elements []DatasetElement `json:"elements"`
}

type listDatasetElementArgs struct {
@@ -60,19 +42,14 @@ type listDatasetElementArgs struct {

type getDatasetElementArgs struct {
DatasetID string `json:"datasetID"`
Element string `json:"element"`
Element string `json:"name"`
}

func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]DatasetMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

func (g *GPTScript) ListDatasets(ctx context.Context) ([]DatasetMeta, error) {
out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{
Input: "{}",
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
Input: "{}",
DatasetTool: g.globalOpts.DatasetTool,
Env: g.globalOpts.Env,
})
if err != nil {
return nil, err
@@ -85,98 +62,42 @@ func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]Dat
return datasets, nil
}

func (g *GPTScript) CreateDataset(ctx context.Context, workspaceID, name, description string) (Dataset, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := createDatasetArgs{
Name: name,
Description: description,
}
argsJSON, err := json.Marshal(args)
if err != nil {
return Dataset{}, fmt.Errorf("failed to marshal dataset args: %w", err)
}

out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return Dataset{}, err
}

var dataset Dataset
if err = json.Unmarshal([]byte(out), &dataset); err != nil {
return Dataset{}, err
}
return dataset, nil
type DatasetOptions struct {
Name, Description string
}

func (g *GPTScript) AddDatasetElement(ctx context.Context, workspaceID, datasetID, elementName, elementDescription string, elementContent []byte) (DatasetElementMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := addDatasetElementArgs{
DatasetID: datasetID,
ElementName: elementName,
ElementDescription: elementDescription,
ElementContent: base64.StdEncoding.EncodeToString(elementContent),
}
argsJSON, err := json.Marshal(args)
if err != nil {
return DatasetElementMeta{}, fmt.Errorf("failed to marshal element args: %w", err)
}

out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return DatasetElementMeta{}, err
}

var element DatasetElementMeta
if err = json.Unmarshal([]byte(out), &element); err != nil {
return DatasetElementMeta{}, err
}
return element, nil
func (g *GPTScript) CreateDatasetWithElements(ctx context.Context, elements []DatasetElement, options ...DatasetOptions) (string, error) {
return g.AddDatasetElements(ctx, "", elements, options...)
}

func (g *GPTScript) AddDatasetElements(ctx context.Context, workspaceID, datasetID string, elements []DatasetElement) error {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

func (g *GPTScript) AddDatasetElements(ctx context.Context, datasetID string, elements []DatasetElement, options ...DatasetOptions) (string, error) {
args := addDatasetElementsArgs{
DatasetID: datasetID,
Elements: elements,
}

for _, opt := range options {
if opt.Name != "" {
args.Name = opt.Name
}
if opt.Description != "" {
args.Description = opt.Description
}
}

argsJSON, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal element args: %w", err)
return "", fmt.Errorf("failed to marshal element args: %w", err)
}

_, err = g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
return g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
Input: string(argsJSON),
DatasetTool: g.globalOpts.DatasetTool,
Env: g.globalOpts.Env,
})
return err
}

func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datasetID string) ([]DatasetElementMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

func (g *GPTScript) ListDatasetElements(ctx context.Context, datasetID string) ([]DatasetElementMeta, error) {
args := listDatasetElementArgs{
DatasetID: datasetID,
}
@@ -186,10 +107,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datase
}

out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
Input: string(argsJSON),
DatasetTool: g.globalOpts.DatasetTool,
Env: g.globalOpts.Env,
})
if err != nil {
return nil, err
@@ -202,11 +122,7 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datase
return elements, nil
}

func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetID, elementName string) (DatasetElement, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

func (g *GPTScript) GetDatasetElement(ctx context.Context, datasetID, elementName string) (DatasetElement, error) {
args := getDatasetElementArgs{
DatasetID: datasetID,
Element: elementName,
@@ -217,10 +133,9 @@ func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetI
}

out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
Input: string(argsJSON),
DatasetTool: g.globalOpts.DatasetTool,
Env: g.globalOpts.Env,
})
if err != nil {
return DatasetElement{}, err
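
For orientation, below is a minimal, illustrative sketch of the reworked dataset API as it appears in this diff: datasets are created and populated through CreateDatasetWithElements/AddDatasetElements, the workspace is supplied via the GPTSCRIPT_WORKSPACE_ID environment variable instead of a workspaceID argument, and text versus binary data go in Contents versus BinaryContents. The import path, wrapper function, and element names are assumptions for the example, not part of the change.

package example

import (
	"context"
	"fmt"
	"os"

	gptscript "github.com/gptscript-ai/go-gptscript"
)

// datasetSketch is an illustrative helper, not part of this PR.
func datasetSketch(ctx context.Context, workspaceID string) error {
	// The workspace now travels through the environment rather than a parameter.
	client, err := gptscript.NewGPTScript(gptscript.GlobalOptions{
		Env: append(os.Environ(), "GPTSCRIPT_WORKSPACE_ID="+workspaceID),
	})
	if err != nil {
		return err
	}

	// Create the dataset and its first elements in one call.
	datasetID, err := client.CreateDatasetWithElements(ctx, []gptscript.DatasetElement{
		{
			DatasetElementMeta: gptscript.DatasetElementMeta{Name: "notes", Description: "a text element"},
			Contents:           "plain-text contents",
		},
		{
			DatasetElementMeta: gptscript.DatasetElementMeta{Name: "blob", Description: "a binary element"},
			BinaryContents:     []byte{0x00, 0x01, 0x02},
		},
	}, gptscript.DatasetOptions{Name: "example-dataset", Description: "sketch dataset"})
	if err != nil {
		return err
	}

	// Text comes back in Contents; binary data comes back in BinaryContents.
	element, err := client.GetDatasetElement(ctx, datasetID, "notes")
	if err != nil {
		return err
	}
	fmt.Println(element.Contents)
	return nil
}
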
74 changes: 48 additions & 26 deletions datasets_test.go
@@ -2,6 +2,7 @@ package gptscript

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
@@ -11,66 +12,87 @@ func TestDatasets(t *testing.T) {
workspaceID, err := g.CreateWorkspace(context.Background(), "directory")
require.NoError(t, err)

client, err := NewGPTScript(GlobalOptions{
OpenAIAPIKey: os.Getenv("OPENAI_API_KEY"),
Env: append(os.Environ(), "GPTSCRIPT_WORKSPACE_ID="+workspaceID),
})
require.NoError(t, err)

defer func() {
_ = g.DeleteWorkspace(context.Background(), workspaceID)
}()

// Create a dataset
dataset, err := g.CreateDataset(context.Background(), workspaceID, "test-dataset", "This is a test dataset")
require.NoError(t, err)
require.Equal(t, "test-dataset", dataset.Name)
require.Equal(t, "This is a test dataset", dataset.Description)
require.Equal(t, 0, len(dataset.Elements))

// Add an element
elementMeta, err := g.AddDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element", "This is a test element", []byte("This is the content"))
datasetID, err := client.CreateDatasetWithElements(context.Background(), []DatasetElement{
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-1",
Description: "This is a test element 1",
},
Contents: "This is the content 1",
},
}, DatasetOptions{
Name: "test-dataset",
Description: "this is a test dataset",
})
require.NoError(t, err)
require.Equal(t, "test-element", elementMeta.Name)
require.Equal(t, "This is a test element", elementMeta.Description)

// Add two more
err = g.AddDatasetElements(context.Background(), workspaceID, dataset.ID, []DatasetElement{
// Add three more elements
_, err = client.AddDatasetElements(context.Background(), datasetID, []DatasetElement{
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-2",
Description: "This is a test element 2",
},
Contents: []byte("This is the content 2"),
Contents: "This is the content 2",
},
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-3",
Description: "This is a test element 3",
},
Contents: []byte("This is the content 3"),
Contents: "This is the content 3",
},
{
DatasetElementMeta: DatasetElementMeta{
Name: "binary-element",
Description: "this element has binary contents",
},
BinaryContents: []byte("binary contents"),
},
})
require.NoError(t, err)

// Get the first element
element, err := g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element")
element, err := client.GetDatasetElement(context.Background(), datasetID, "test-element-1")
require.NoError(t, err)
require.Equal(t, "test-element", element.Name)
require.Equal(t, "This is a test element", element.Description)
require.Equal(t, []byte("This is the content"), element.Contents)
require.Equal(t, "test-element-1", element.Name)
require.Equal(t, "This is a test element 1", element.Description)
require.Equal(t, "This is the content 1", element.Contents)

// Get the third element
element, err = g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element-3")
element, err = client.GetDatasetElement(context.Background(), datasetID, "test-element-3")
require.NoError(t, err)
require.Equal(t, "test-element-3", element.Name)
require.Equal(t, "This is a test element 3", element.Description)
require.Equal(t, []byte("This is the content 3"), element.Contents)
require.Equal(t, "This is the content 3", element.Contents)

// Get the binary element
element, err = client.GetDatasetElement(context.Background(), datasetID, "binary-element")
require.NoError(t, err)
require.Equal(t, "binary-element", element.Name)
require.Equal(t, "this element has binary contents", element.Description)
require.Equal(t, []byte("binary contents"), element.BinaryContents)

// List elements in the dataset
elements, err := g.ListDatasetElements(context.Background(), workspaceID, dataset.ID)
elements, err := client.ListDatasetElements(context.Background(), datasetID)
require.NoError(t, err)
require.Equal(t, 3, len(elements))
require.Equal(t, 4, len(elements))

// List datasets
datasets, err := g.ListDatasets(context.Background(), workspaceID)
datasets, err := client.ListDatasets(context.Background())
require.NoError(t, err)
require.Equal(t, 1, len(datasets))
require.Equal(t, datasetID, datasets[0].ID)
require.Equal(t, "test-dataset", datasets[0].Name)
require.Equal(t, "This is a test dataset", datasets[0].Description)
require.Equal(t, dataset.ID, datasets[0].ID)
require.Equal(t, "this is a test dataset", datasets[0].Description)
}
4 changes: 2 additions & 2 deletions opts.go
@@ -11,7 +11,7 @@ type GlobalOptions struct {
DefaultModelProvider string `json:"DefaultModelProvider"`
CacheDir string `json:"CacheDir"`
Env []string `json:"env"`
DatasetToolRepo string `json:"DatasetToolRepo"`
DatasetTool string `json:"DatasetTool"`
WorkspaceTool string `json:"WorkspaceTool"`
}

@@ -46,7 +46,7 @@ func completeGlobalOptions(opts ...GlobalOptions) GlobalOptions {
result.OpenAIBaseURL = firstSet(opt.OpenAIBaseURL, result.OpenAIBaseURL)
result.DefaultModel = firstSet(opt.DefaultModel, result.DefaultModel)
result.DefaultModelProvider = firstSet(opt.DefaultModelProvider, result.DefaultModelProvider)
result.DatasetToolRepo = firstSet(opt.DatasetToolRepo, result.DatasetToolRepo)
result.DatasetTool = firstSet(opt.DatasetTool, result.DatasetTool)
result.WorkspaceTool = firstSet(opt.WorkspaceTool, result.WorkspaceTool)
result.Env = append(result.Env, opt.Env...)
}