Skip to content
This repository was archived by the owner on Jun 14, 2019. It is now read-only.

Commit 395bcf3

Browse files
authored
Add insert select support (#39)
* add insert select support * refactor insert select * improve sort * update test * fix * hide fiddle sql tests * update README * update README
1 parent 377feed commit 395bcf3

13 files changed

+148
-62
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ Make sure you have installed Go 1.8+ and then:
1313

1414
```Go
1515
sql, args, err := builder.Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()
16+
17+
// INSERT INTO table1 SELECT * FROM table2
18+
sql, err := builder.Insert().Into("table1").Select().From("table2").ToBoundSQL()
19+
20+
// INSERT INTO table1 (a, b) SELECT b, c FROM table2
21+
sql, err = builder.Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
1622
```
1723

1824
# Select

builder.go

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package builder
77
import (
88
sql2 "database/sql"
99
"fmt"
10+
"sort"
1011
)
1112

1213
type optype byte
@@ -49,14 +50,16 @@ type Builder struct {
4950
optype
5051
dialect string
5152
isNested bool
52-
tableName string
53+
into string
54+
from string
5355
subQuery *Builder
5456
cond Cond
5557
selects []string
5658
joins []join
5759
unions []union
5860
limitation *limit
59-
inserts Eq
61+
insertCols []string
62+
insertVals []interface{}
6063
updates []Eq
6164
orderBy string
6265
groupBy string
@@ -111,15 +114,15 @@ func (b *Builder) From(subject interface{}, alias ...string) *Builder {
111114
b.subQuery = subject.(*Builder)
112115

113116
if len(alias) > 0 {
114-
b.tableName = alias[0]
117+
b.from = alias[0]
115118
} else {
116119
b.isNested = true
117120
}
118121
case string:
119-
b.tableName = subject.(string)
122+
b.from = subject.(string)
120123

121124
if len(alias) > 0 {
122-
b.tableName = b.tableName + " " + alias[0]
125+
b.from = b.from + " " + alias[0]
123126
}
124127
}
125128

@@ -128,12 +131,15 @@ func (b *Builder) From(subject interface{}, alias ...string) *Builder {
128131

129132
// TableName returns the table name
130133
func (b *Builder) TableName() string {
131-
return b.tableName
134+
if b.optype == insertType {
135+
return b.into
136+
}
137+
return b.from
132138
}
133139

134140
// Into sets insert table name
135141
func (b *Builder) Into(tableName string) *Builder {
136-
b.tableName = tableName
142+
b.into = tableName
137143
return b
138144
}
139145

@@ -221,7 +227,9 @@ func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
221227
// Select sets select SQL
222228
func (b *Builder) Select(cols ...string) *Builder {
223229
b.selects = cols
224-
b.optype = selectType
230+
if b.optype == condType {
231+
b.optype = selectType
232+
}
225233
return b
226234
}
227235

@@ -238,8 +246,40 @@ func (b *Builder) Or(cond Cond) *Builder {
238246
}
239247

240248
// Insert sets insert SQL
241-
func (b *Builder) Insert(eq Eq) *Builder {
242-
b.inserts = eq
249+
func (b *Builder) Insert(eq ...interface{}) *Builder {
250+
if len(eq) > 0 {
251+
var paramType = -1
252+
for _, e := range eq {
253+
switch t := e.(type) {
254+
case Eq:
255+
if paramType == -1 {
256+
paramType = 0
257+
}
258+
if paramType != 0 {
259+
break
260+
}
261+
for k, v := range t {
262+
b.insertCols = append(b.insertCols, k)
263+
b.insertVals = append(b.insertVals, v)
264+
}
265+
case string:
266+
if paramType == -1 {
267+
paramType = 1
268+
}
269+
if paramType != 1 {
270+
break
271+
}
272+
b.insertCols = append(b.insertCols, t)
273+
}
274+
}
275+
}
276+
277+
if len(b.insertCols) == len(b.insertVals) {
278+
sort.Slice(b.insertVals, func(i, j int) bool {
279+
return b.insertCols[i] < b.insertCols[j]
280+
})
281+
sort.Strings(b.insertCols)
282+
}
243283
b.optype = insertType
244284
return b
245285
}

builder_delete.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ func Delete(conds ...Cond) *Builder {
1515
}
1616

1717
func (b *Builder) deleteWriteTo(w Writer) error {
18-
if len(b.tableName) <= 0 {
18+
if len(b.from) <= 0 {
1919
return ErrNoTableName
2020
}
2121

22-
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.tableName); err != nil {
22+
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.from); err != nil {
2323
return err
2424
}
2525

builder_insert.go

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,49 @@ import (
1010
)
1111

1212
// Insert creates an insert Builder
13-
func Insert(eq Eq) *Builder {
13+
func Insert(eq ...interface{}) *Builder {
1414
builder := &Builder{cond: NewCond()}
15-
return builder.Insert(eq)
15+
return builder.Insert(eq...)
16+
}
17+
18+
func (b *Builder) insertSelectWriteTo(w Writer) error {
19+
if _, err := fmt.Fprintf(w, "INSERT INTO %s ", b.into); err != nil {
20+
return err
21+
}
22+
23+
if len(b.insertCols) > 0 {
24+
fmt.Fprintf(w, "(")
25+
for _, col := range b.insertCols {
26+
fmt.Fprintf(w, col)
27+
}
28+
fmt.Fprintf(w, ") ")
29+
}
30+
31+
return b.selectWriteTo(w)
1632
}
1733

1834
func (b *Builder) insertWriteTo(w Writer) error {
19-
if len(b.tableName) <= 0 {
35+
if len(b.into) <= 0 {
2036
return ErrNoTableName
2137
}
22-
if len(b.inserts) <= 0 {
38+
if len(b.insertCols) <= 0 && b.from == "" {
2339
return ErrNoColumnToInsert
2440
}
2541

26-
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil {
42+
if b.into != "" && b.from != "" {
43+
return b.insertSelectWriteTo(w)
44+
}
45+
46+
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.into); err != nil {
2747
return err
2848
}
2949

3050
var args = make([]interface{}, 0)
3151
var bs []byte
3252
var valBuffer = bytes.NewBuffer(bs)
33-
var i = 0
3453

35-
for _, col := range b.inserts.sortedKeys() {
36-
value := b.inserts[col]
54+
for i, col := range b.insertCols {
55+
value := b.insertVals[i]
3756
fmt.Fprint(w, col)
3857
if e, ok := value.(expr); ok {
3958
fmt.Fprintf(valBuffer, "(%s)", e.sql)
@@ -43,15 +62,14 @@ func (b *Builder) insertWriteTo(w Writer) error {
4362
args = append(args, value)
4463
}
4564

46-
if i != len(b.inserts)-1 {
65+
if i != len(b.insertCols)-1 {
4766
if _, err := fmt.Fprint(w, ","); err != nil {
4867
return err
4968
}
5069
if _, err := fmt.Fprint(valBuffer, ","); err != nil {
5170
return err
5271
}
5372
}
54-
i = i + 1
5573
}
5674

5775
if _, err := fmt.Fprint(w, ") Values ("); err != nil {

builder_insert_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright 2018 The Xorm Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package builder
6+
7+
import (
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestBuilderInsert(t *testing.T) {
14+
sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL()
15+
assert.NoError(t, err)
16+
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql)
17+
18+
sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL()
19+
assert.NoError(t, err)
20+
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql)
21+
22+
sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL()
23+
assert.Error(t, err)
24+
assert.EqualValues(t, ErrNoTableName, err)
25+
assert.EqualValues(t, "", sql)
26+
27+
sql, err = Insert(Eq{}).Into("table1").ToBoundSQL()
28+
assert.Error(t, err)
29+
assert.EqualValues(t, ErrNoColumnToInsert, err)
30+
assert.EqualValues(t, "", sql)
31+
}
32+
33+
func TestBuidlerInsert_Select(t *testing.T) {
34+
sql, err := Insert().Into("table1").Select().From("table2").ToBoundSQL()
35+
assert.NoError(t, err)
36+
assert.EqualValues(t, "INSERT INTO table1 SELECT * FROM table2", sql)
37+
38+
sql, err = Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
39+
assert.NoError(t, err)
40+
assert.EqualValues(t, "INSERT INTO table1 (a, b) SELECT b, c FROM table2", sql)
41+
}

builder_limit.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ func (b *Builder) limitWriteTo(w Writer) error {
5656
case SQLITE, MYSQL, POSTGRES:
5757
// if type UNION, we need to write previous content back to current writer
5858
if b.optype == unionType {
59-
b.WriteTo(ow)
59+
if err := b.WriteTo(ow); err != nil {
60+
return err
61+
}
6062
}
6163

6264
if limit.offset == 0 {

builder_limit_test.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,7 @@
44

55
package builder
66

7-
import (
8-
"testing"
9-
10-
"github.com/stretchr/testify/assert"
11-
)
12-
7+
/*
138
func TestBuilder_Limit4Mssql(t *testing.T) {
149
sqlFromFile, err := readPreparationSQLFromFile("testdata/mssql_fiddle_data.sql")
1510
assert.NoError(t, err)
@@ -126,4 +121,4 @@ func TestBuilder_Limit4Oracle(t *testing.T) {
126121
assert.NoError(t, err)
127122
assert.EqualValues(t, "SELECT a,b,c FROM (SELECT a,b,c,ROWNUM RN FROM ((SELECT a,b,c FROM (SELECT * FROM (SELECT a,b,c,ROWNUM RN FROM table1 WHERE a<>'0' ORDER BY a ASC) at WHERE at.RN<=15) att WHERE att.RN>10) UNION ALL (SELECT a,b,c FROM (SELECT a,b,c,ROWNUM RN FROM table1 WHERE b<>'48' ORDER BY a DESC) at WHERE at.RN<=10)) at) at WHERE at.RN<=3", sql)
128123
assert.NoError(t, f.executableCheck(sql))
129-
}
124+
}*/

builder_select.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func Select(cols ...string) *Builder {
1515
}
1616

1717
func (b *Builder) selectWriteTo(w Writer) error {
18-
if len(b.tableName) <= 0 && !b.isNested {
18+
if len(b.from) <= 0 && !b.isNested {
1919
return ErrNoTableName
2020
}
2121

@@ -46,11 +46,11 @@ func (b *Builder) selectWriteTo(w Writer) error {
4646
}
4747

4848
if b.subQuery == nil {
49-
if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil {
49+
if _, err := fmt.Fprint(w, " FROM ", b.from); err != nil {
5050
return err
5151
}
5252
} else {
53-
if b.cond.IsValid() && len(b.tableName) <= 0 {
53+
if b.cond.IsValid() && len(b.from) <= 0 {
5454
return ErrUnnamedDerivedTable
5555
}
5656
if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect {
@@ -69,10 +69,10 @@ func (b *Builder) selectWriteTo(w Writer) error {
6969
return err
7070
}
7171

72-
if len(b.tableName) == 0 {
72+
if len(b.from) == 0 {
7373
fmt.Fprintf(w, ")")
7474
} else {
75-
fmt.Fprintf(w, ") %v", b.tableName)
75+
fmt.Fprintf(w, ") %v", b.from)
7676
}
7777
default:
7878
return ErrUnexpectedSubQuery

builder_select_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ func TestBuilder_Select(t *testing.T) {
1515
sql, args, err := Select("c, d").From("table1").ToSQL()
1616
assert.NoError(t, err)
1717
assert.EqualValues(t, "SELECT c, d FROM table1", sql)
18+
assert.EqualValues(t, []interface{}(nil), args)
1819

1920
sql, args, err = Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL()
2021
assert.NoError(t, err)
@@ -104,24 +105,24 @@ func TestBuilder_From(t *testing.T) {
104105
assert.EqualValues(t, []interface{}{1, 2, 1}, args)
105106

106107
// from union without alias
107-
sql, args, err = Select("sub.id").From(
108+
_, _, err = Select("sub.id").From(
108109
Select("id").From("table1").Where(Eq{"a": 1}).Union(
109110
"all", Select("id").From("table1").Where(Eq{"a": 2}))).Where(Eq{"b": 1}).ToSQL()
110111
assert.Error(t, err)
111112
assert.EqualValues(t, ErrUnnamedDerivedTable, err)
112113

113114
// will raise error
114-
sql, args, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL()
115+
_, _, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL()
115116
assert.Error(t, err)
116117
assert.EqualValues(t, ErrUnexpectedSubQuery, err)
117118

118119
// will raise error
119-
sql, args, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
120+
_, _, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
120121
assert.Error(t, err)
121122
assert.EqualValues(t, ErrUnexpectedSubQuery, err)
122123

123124
// from a sub-query in different dialect
124-
sql, args, err = MySQL().Select("sub.id").From(
125+
_, _, err = MySQL().Select("sub.id").From(
125126
Oracle().Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
126127
assert.Error(t, err)
127128
assert.EqualValues(t, ErrInconsistentDialect, err)

builder_test.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -596,24 +596,6 @@ func TestBuilderCond(t *testing.T) {
596596
}
597597
}
598598

599-
func TestBuilderInsert(t *testing.T) {
600-
sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL()
601-
assert.NoError(t, err)
602-
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql)
603-
604-
sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL()
605-
assert.NoError(t, err)
606-
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql)
607-
608-
sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL()
609-
assert.Error(t, err)
610-
assert.EqualValues(t, ErrNoTableName, err)
611-
612-
sql, err = Insert(Eq{}).Into("table1").ToBoundSQL()
613-
assert.Error(t, err)
614-
assert.EqualValues(t, ErrNoColumnToInsert, err)
615-
}
616-
617599
func TestSubquery(t *testing.T) {
618600
subb := Select("id").From("table_b").Where(Eq{"b": "a"})
619601
b := Select("a, b").From("table_a").Where(

0 commit comments

Comments
 (0)