Skip to content

Commit 4507ede

Browse files
feat(plugin): Use gRPC interface for codegen plugin communication (#2930)
* feat(plugin): Use gRPC interface for codegen plugin communication * rename proto rpc service and messages * make invoke methods more generic * remove vtproto and add regular grpc buf plugin
1 parent a225849 commit 4507ede

23 files changed

+346
-6986
lines changed

buf.gen.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ plugins:
55
- plugin: buf.build/protocolbuffers/go:v1.30.0
66
out: internal
77
opt: paths=source_relative
8-
- plugin: buf.build/community/planetscale-vtprotobuf:v0.4.0
8+
- plugin: buf.build/grpc/go:v1.3.0
99
out: internal
1010
opt: paths=source_relative

cmd/sqlc-gen-json/main.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/sqlc-dev/sqlc/internal/codegen/json"
1111
"github.com/sqlc-dev/sqlc/internal/plugin"
12+
"google.golang.org/protobuf/proto"
1213
)
1314

1415
func main() {
@@ -19,19 +20,19 @@ func main() {
1920
}
2021

2122
func run() error {
22-
var req plugin.CodeGenRequest
23+
var req plugin.GenerateRequest
2324
reqBlob, err := io.ReadAll(os.Stdin)
2425
if err != nil {
2526
return err
2627
}
27-
if err := req.UnmarshalVT(reqBlob); err != nil {
28+
if err := proto.Unmarshal(reqBlob, &req); err != nil {
2829
return err
2930
}
3031
resp, err := json.Generate(context.Background(), &req)
3132
if err != nil {
3233
return err
3334
}
34-
respBlob, err := resp.MarshalVT()
35+
respBlob, err := proto.Marshal(resp)
3536
if err != nil {
3637
return err
3738
}

internal/cmd/generate.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"sync"
1515

1616
"golang.org/x/sync/errgroup"
17+
"google.golang.org/grpc"
1718
"google.golang.org/grpc/status"
1819

1920
"github.com/sqlc-dev/sqlc/internal/codegen/golang"
@@ -380,10 +381,10 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
380381
return c.Result(), false
381382
}
382383

383-
func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.CodeGenResponse, error) {
384+
func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) {
384385
defer trace.StartRegion(ctx, "codegen").End()
385386
req := codeGenRequest(result, combo)
386-
var handler ext.Handler
387+
var handler grpc.ClientConnInterface
387388
var out string
388389
switch {
389390
case sql.Plugin != nil:
@@ -453,6 +454,7 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
453454
default:
454455
return "", nil, fmt.Errorf("missing language backend")
455456
}
456-
resp, err := handler.Generate(ctx, req)
457+
client := plugin.NewCodegenServiceClient(handler)
458+
resp, err := client.Generate(ctx, req)
457459
return out, resp, err
458460
}

internal/cmd/shim.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ func pluginQueryParam(p compiler.Parameter) *plugin.Parameter {
223223
}
224224
}
225225

226-
func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.CodeGenRequest {
227-
return &plugin.CodeGenRequest{
226+
func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.GenerateRequest {
227+
return &plugin.GenerateRequest{
228228
Settings: pluginSettings(r, settings),
229229
Catalog: pluginCatalog(r.Catalog),
230230
Queries: pluginQueries(r),

internal/cmd/vet.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error {
629629
return nil
630630
}
631631

632-
func vetConfig(req *plugin.CodeGenRequest) *vet.Config {
632+
func vetConfig(req *plugin.GenerateRequest) *vet.Config {
633633
return &vet.Config{
634634
Version: req.Settings.Version,
635635
Engine: req.Settings.Engine,

internal/codegen/golang/gen.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
103103
}
104104
}
105105

106-
func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
106+
func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
107107
options, err := opts.Parse(req)
108108
if err != nil {
109109
return nil, err
@@ -127,7 +127,7 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR
127127
return generate(req, options, enums, structs, queries)
128128
}
129129

130-
func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.CodeGenResponse, error) {
130+
func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.GenerateResponse, error) {
131131
i := &importer{
132132
Options: options,
133133
Queries: queries,
@@ -282,7 +282,7 @@ func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, s
282282
return nil, err
283283
}
284284
}
285-
resp := plugin.CodeGenResponse{}
285+
resp := plugin.GenerateResponse{}
286286

287287
for filename, code := range output {
288288
resp.Files = append(resp.Files, &plugin.File{

internal/codegen/golang/go_type.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"github.com/sqlc-dev/sqlc/internal/plugin"
99
)
1010

11-
func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) {
11+
func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) {
1212
for _, override := range options.Overrides {
1313
oride := override.ShimOverride
1414
if oride.GoType.StructTags == nil {
@@ -33,7 +33,7 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, op
3333
}
3434
}
3535

36-
func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
36+
func goType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
3737
// Check if the column's type has been overridden
3838
for _, override := range options.Overrides {
3939
oride := override.ShimOverride
@@ -63,7 +63,7 @@ func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Colum
6363
return typ
6464
}
6565

66-
func goInnerType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
66+
func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
6767
columnType := sdk.DataType(col.Type)
6868
notNull := col.NotNull || col.IsArray
6969

internal/codegen/golang/mysql_type.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/sqlc-dev/sqlc/internal/plugin"
1010
)
1111

12-
func mysqlType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
12+
func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
1313
columnType := sdk.DataType(col.Type)
1414
notNull := col.NotNull || col.IsArray
1515
unsigned := col.Unsigned

internal/codegen/golang/opts/options.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ type GlobalOptions struct {
4747
Rename map[string]string `json:"rename,omitempty" yaml:"rename"`
4848
}
4949

50-
func Parse(req *plugin.CodeGenRequest) (*Options, error) {
50+
func Parse(req *plugin.GenerateRequest) (*Options, error) {
5151
options, err := parseOpts(req)
5252
if err != nil {
5353
return nil, err
@@ -68,7 +68,7 @@ func Parse(req *plugin.CodeGenRequest) (*Options, error) {
6868
return options, nil
6969
}
7070

71-
func parseOpts(req *plugin.CodeGenRequest) (*Options, error) {
71+
func parseOpts(req *plugin.GenerateRequest) (*Options, error) {
7272
var options Options
7373
if len(req.PluginOptions) == 0 {
7474
return &options, nil
@@ -91,7 +91,7 @@ func parseOpts(req *plugin.CodeGenRequest) (*Options, error) {
9191
return &options, nil
9292
}
9393

94-
func parseGlobalOpts(req *plugin.CodeGenRequest) (*GlobalOptions, error) {
94+
func parseGlobalOpts(req *plugin.GenerateRequest) (*GlobalOptions, error) {
9595
var options GlobalOptions
9696
if len(req.GlobalOptions) == 0 {
9797
return &options, nil

internal/codegen/golang/opts/override.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool {
7676
return true
7777
}
7878

79-
func (o *Override) parse(req *plugin.CodeGenRequest) (err error) {
79+
func (o *Override) parse(req *plugin.GenerateRequest) (err error) {
8080
// validate deprecated postgres_type field
8181
if o.Deprecated_PostgresType != "" {
8282
fmt.Fprintf(os.Stderr, "WARNING: \"postgres_type\" is deprecated. Instead, use \"db_type\" to specify a type override.\n")

internal/codegen/golang/opts/shim.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type ShimOverride struct {
2121
GoType *ShimGoType
2222
}
2323

24-
func shimOverride(req *plugin.CodeGenRequest, o *Override) *ShimOverride {
24+
func shimOverride(req *plugin.GenerateRequest, o *Override) *ShimOverride {
2525
var column string
2626
var table plugin.Identifier
2727

internal/codegen/golang/postgresql_type.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func parseIdentifierString(name string) (*plugin.Identifier, error) {
3434
}
3535
}
3636

37-
func postgresType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
37+
func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
3838
columnType := sdk.DataType(col.Type)
3939
notNull := col.NotNull || col.IsArray
4040
driver := parseDriver(options.SqlPackage)

internal/codegen/golang/result.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
"github.com/sqlc-dev/sqlc/internal/plugin"
1313
)
1414

15-
func buildEnums(req *plugin.CodeGenRequest, options *opts.Options) []Enum {
15+
func buildEnums(req *plugin.GenerateRequest, options *opts.Options) []Enum {
1616
var enums []Enum
1717
for _, schema := range req.Catalog.Schemas {
1818
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
@@ -59,7 +59,7 @@ func buildEnums(req *plugin.CodeGenRequest, options *opts.Options) []Enum {
5959
return enums
6060
}
6161

62-
func buildStructs(req *plugin.CodeGenRequest, options *opts.Options) []Struct {
62+
func buildStructs(req *plugin.GenerateRequest, options *opts.Options) []Struct {
6363
var structs []Struct
6464
for _, schema := range req.Catalog.Schemas {
6565
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
@@ -182,7 +182,7 @@ func argName(name string) string {
182182
return out
183183
}
184184

185-
func buildQueries(req *plugin.CodeGenRequest, options *opts.Options, structs []Struct) ([]Query, error) {
185+
func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []Struct) ([]Query, error) {
186186
qs := make([]Query, 0, len(req.Queries))
187187
for _, query := range req.Queries {
188188
if query.Name == "" {
@@ -332,7 +332,7 @@ func putOutColumns(query *plugin.Query) bool {
332332
// JSON tags: count, count_2, count_2
333333
//
334334
// This is unlikely to happen, so don't fix it yet
335-
func columnsToStruct(req *plugin.CodeGenRequest, options *opts.Options, name string, columns []goColumn, useID bool) (*Struct, error) {
335+
func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name string, columns []goColumn, useID bool) (*Struct, error) {
336336
gs := Struct{
337337
Name: name,
338338
}

internal/codegen/golang/sqlite_type.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/sqlc-dev/sqlc/internal/plugin"
1010
)
1111

12-
func sqliteType(req *plugin.CodeGenRequest, col *plugin.Column) string {
12+
func sqliteType(req *plugin.GenerateRequest, col *plugin.Column) string {
1313
dt := strings.ToLower(sdk.DataType(col.Type))
1414
notNull := col.NotNull || col.IsArray
1515

internal/codegen/json/gen.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
"github.com/sqlc-dev/sqlc/internal/plugin"
1212
)
1313

14-
func parseOptions(req *plugin.CodeGenRequest) (*opts, error) {
14+
func parseOptions(req *plugin.GenerateRequest) (*opts, error) {
1515
if len(req.PluginOptions) == 0 {
1616
return new(opts), nil
1717
}
@@ -25,7 +25,7 @@ func parseOptions(req *plugin.CodeGenRequest) (*opts, error) {
2525
return options, nil
2626
}
2727

28-
func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
28+
func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
2929
options, err := parseOptions(req)
3030
if err != nil {
3131
return nil, err
@@ -57,7 +57,7 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR
5757
if err != nil {
5858
return nil, err
5959
}
60-
return &plugin.CodeGenResponse{
60+
return &plugin.GenerateResponse{
6161
Files: []*plugin.File{
6262
{
6363
Name: filename,

internal/ext/handler.go

+33-4
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,51 @@ package ext
22

33
import (
44
"context"
5+
"fmt"
6+
7+
"google.golang.org/grpc"
8+
"google.golang.org/grpc/codes"
9+
"google.golang.org/grpc/status"
510

611
"github.com/sqlc-dev/sqlc/internal/plugin"
712
)
813

914
type Handler interface {
10-
Generate(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)
15+
Generate(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)
16+
17+
Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error
18+
NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error)
1119
}
1220

1321
type wrapper struct {
14-
fn func(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)
22+
fn func(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)
1523
}
1624

17-
func (w *wrapper) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
25+
func (w *wrapper) Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
1826
return w.fn(ctx, req)
1927
}
2028

21-
func HandleFunc(fn func(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)) Handler {
29+
func (w *wrapper) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
30+
req, ok := args.(*plugin.GenerateRequest)
31+
if !ok {
32+
return fmt.Errorf("args isn't a GenerateRequest")
33+
}
34+
resp, ok := reply.(*plugin.GenerateResponse)
35+
if !ok {
36+
return fmt.Errorf("reply isn't a GenerateResponse")
37+
}
38+
res, err := w.Generate(ctx, req)
39+
if err != nil {
40+
return err
41+
}
42+
resp.Files = res.Files
43+
return nil
44+
}
45+
46+
func (w *wrapper) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
47+
return nil, status.Error(codes.Unimplemented, "")
48+
}
49+
50+
func HandleFunc(fn func(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)) Handler {
2251
return &wrapper{fn}
2352
}

0 commit comments

Comments
 (0)