Skip to content
This repository was archived by the owner on Jun 14, 2019. It is now read-only.

Commit 6e0d98f

Browse files
BetaCat0lunny
authored andcommitted
Update From & more tests (#35)
* add more test cases & replace some err in errors.New(...) style & add some built-in err types * update From & related test cases * add test cases & fix bug * fix issues * update readme * replace postgres placeholder * Update From & add more tests * fix issues * compatibility with sql.NamedArg & ref test case * compatibility with sql.NamedArg & ref test case * update bench test * enhance placeholder: in case of sql.NamedArg in args * enhance ToSQL: compatible with different sql drivers
1 parent 145c996 commit 6e0d98f

9 files changed

+179
-18
lines changed

builder.go

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
package builder
66

7+
import (
8+
sql2 "database/sql"
9+
"fmt"
10+
)
11+
712
type optype byte
813

914
const (
@@ -64,6 +69,31 @@ func Dialect(dialect string) *Builder {
6469
return builder
6570
}
6671

72+
// MySQL is shortcut of Dialect(MySQL)
73+
func MySQL() *Builder {
74+
return Dialect(MYSQL)
75+
}
76+
77+
// MsSQL is shortcut of Dialect(MsSQL)
78+
func MsSQL() *Builder {
79+
return Dialect(MSSQL)
80+
}
81+
82+
// Oracle is shortcut of Dialect(Oracle)
83+
func Oracle() *Builder {
84+
return Dialect(ORACLE)
85+
}
86+
87+
// Postgres is shortcut of Dialect(Postgres)
88+
func Postgres() *Builder {
89+
return Dialect(POSTGRES)
90+
}
91+
92+
// SQLite is shortcut of Dialect(SQLITE)
93+
func SQLite() *Builder {
94+
return Dialect(SQLITE)
95+
}
96+
6797
// Where sets where SQL
6898
func (b *Builder) Where(cond Cond) *Builder {
6999
if b.cond.IsValid() {
@@ -142,7 +172,10 @@ func (b *Builder) Union(unionTp string, unionCond *Builder) *Builder {
142172
}
143173

144174
if unionCond != nil {
145-
unionCond.dialect = builder.dialect
175+
if unionCond.dialect == "" && builder.dialect != "" {
176+
unionCond.dialect = builder.dialect
177+
}
178+
146179
builder.unions = append(builder.unions, union{unionTp, unionCond})
147180
}
148181

@@ -257,7 +290,40 @@ func (b *Builder) ToSQL() (string, []interface{}, error) {
257290
return "", nil, err
258291
}
259292

260-
return w.writer.String(), w.args, nil
293+
// in case of sql.NamedArg in args
294+
for e := range w.args {
295+
if namedArg, ok := w.args[e].(sql2.NamedArg); ok {
296+
w.args[e] = namedArg.Value
297+
}
298+
}
299+
300+
var sql = w.writer.String()
301+
var err error
302+
303+
switch b.dialect {
304+
case ORACLE, MSSQL:
305+
// This is for compatibility with different sql drivers
306+
for e := range w.args {
307+
w.args[e] = sql2.Named(fmt.Sprintf("p%d", e+1), w.args[e])
308+
}
309+
310+
var prefix string
311+
if b.dialect == ORACLE {
312+
prefix = ":p"
313+
} else {
314+
prefix = "@p"
315+
}
316+
317+
if sql, err = ConvertPlaceholder(sql, prefix); err != nil {
318+
return "", nil, err
319+
}
320+
case POSTGRES:
321+
if sql, err = ConvertPlaceholder(sql, "$"); err != nil {
322+
return "", nil, err
323+
}
324+
}
325+
326+
return sql, w.args, nil
261327
}
262328

263329
// ToBoundSQL

builder_b_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ func randUpdateByCondition(rgc *randGenConf) *Builder {
187187
b := Update(eqs).From("table1")
188188

189189
if rgc.allowCond && rand.Intn(1000) >= 500 {
190-
b = b.Where(randCond(b.selects, 3))
190+
b.Where(randCond(fields, 3))
191191
}
192192

193193
return b

builder_select.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ func (b *Builder) selectWriteTo(w Writer) error {
5353
if b.cond.IsValid() && len(b.tableName) <= 0 {
5454
return ErrUnnamedDerivedTable
5555
}
56+
if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect {
57+
return ErrInconsistentDialect
58+
}
59+
60+
// dialect of sub-query will inherit from the main one (if not set up)
61+
if b.dialect != "" && b.subQuery.dialect == "" {
62+
b.subQuery.dialect = b.dialect
63+
}
5664

5765
switch b.subQuery.optype {
5866
case selectType, unionType:

builder_select_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,24 @@ func TestBuilder_From(t *testing.T) {
119119
sql, args, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
120120
assert.Error(t, err)
121121
assert.EqualValues(t, ErrUnexpectedSubQuery, err)
122+
123+
// from a sub-query in different dialect
124+
sql, args, err = MySQL().Select("sub.id").From(
125+
Oracle().Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
126+
assert.Error(t, err)
127+
assert.EqualValues(t, ErrInconsistentDialect, err)
128+
129+
// from a sub-query (dialect set up)
130+
sql, args, err = MySQL().Select("sub.id").From(
131+
MySQL().Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
132+
assert.NoError(t, err)
133+
assert.EqualValues(t, "SELECT sub.id FROM (SELECT id FROM table1 WHERE a=?) sub WHERE b=?", sql)
134+
assert.EqualValues(t, []interface{}{1, 1}, args)
135+
136+
// from a sub-query (dialect not set up)
137+
sql, args, err = MySQL().Select("sub.id").From(
138+
Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
139+
assert.NoError(t, err)
140+
assert.EqualValues(t, "SELECT sub.id FROM (SELECT id FROM table1 WHERE a=?) sub WHERE b=?", sql)
141+
assert.EqualValues(t, []interface{}{1, 1}, args)
122142
}

builder_union.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ func (b *Builder) unionWriteTo(w Writer) error {
2626
return err
2727
}
2828
} else {
29+
if b.dialect != "" && b.dialect != current.dialect {
30+
return ErrInconsistentDialect
31+
}
32+
2933
if idx != 0 {
3034
fmt.Fprint(w, fmt.Sprintf(" UNION %v ", strings.ToUpper(u.unionType)))
3135
}

builder_union_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,22 @@ func TestBuilder_Union(t *testing.T) {
2121
assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) UNION ALL (SELECT * FROM t2 WHERE status=?) UNION DISTINCT (SELECT * FROM t2 WHERE status=?) UNION (SELECT * FROM t2 WHERE status=?)", sql)
2222
assert.EqualValues(t, []interface{}{"1", "2", "3", "3"}, args)
2323

24+
// sub-query will inherit dialect from the main one
25+
sql, args, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}).
26+
Union("all", Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)).
27+
Union("", Select("*").From("t2").Where(Eq{"status": "3"})).
28+
ToSQL()
29+
assert.NoError(t, err)
30+
assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) UNION ALL (SELECT * FROM t2 WHERE status=? LIMIT 10) UNION (SELECT * FROM t2 WHERE status=?)", sql)
31+
assert.EqualValues(t, []interface{}{"1", "2", "3"}, args)
32+
33+
// will raise error
34+
sql, args, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}).
35+
Union("all", Oracle().Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)).
36+
ToSQL()
37+
assert.Error(t, err)
38+
assert.EqualValues(t, ErrInconsistentDialect, err)
39+
2440
// will raise error
2541
sql, args, err = Select("*").From("table1").Where(Eq{"a": "1"}).
2642
Union("all", Select("*").From("table2").Where(Eq{"a": "2"})).

error.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,6 @@ var (
3535
ErrInvalidLimitation = errors.New("Offset or limit is not correct")
3636
// ErrUnnamedDerivedTable Every derived table must have its own alias
3737
ErrUnnamedDerivedTable = errors.New("Every derived table must have its own alias")
38+
// ErrInconsistentDialect Inconsistent dialect in same builder
39+
ErrInconsistentDialect = errors.New("Inconsistent dialect in same builder")
3840
)

sql.go

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package builder
66

77
import (
8+
sql2 "database/sql"
89
"fmt"
910
"reflect"
1011
"time"
@@ -105,10 +106,15 @@ func ConvertToBoundSQL(sql string, args []interface{}) (string, error) {
105106
return "", ErrNeedMoreArguments
106107
}
107108

108-
if noSQLQuoteNeeded(args[j]) {
109-
_, err = fmt.Fprint(&buf, args[j])
109+
arg := args[j]
110+
if namedArg, ok := arg.(sql2.NamedArg); ok {
111+
arg = namedArg.Value
112+
}
113+
114+
if noSQLQuoteNeeded(arg) {
115+
_, err = fmt.Fprint(&buf, arg)
110116
} else {
111-
_, err = fmt.Fprintf(&buf, "'%v'", args[j])
117+
_, err = fmt.Fprintf(&buf, "'%v'", arg)
112118
}
113119
if err != nil {
114120
return "", err
@@ -129,27 +135,22 @@ func ConvertPlaceholder(sql, prefix string) (string, error) {
129135
var i, j, start int
130136
for ; i < len(sql); i++ {
131137
if sql[i] == '?' {
132-
_, err := buf.WriteString(sql[start:i])
133-
if err != nil {
134-
return "", err
135-
}
136-
start = i + 1
137-
138-
_, err = buf.WriteString(prefix)
139-
if err != nil {
138+
if _, err := buf.WriteString(sql[start:i]); err != nil {
140139
return "", err
141140
}
142141

142+
start = i + 1
143143
j = j + 1
144-
_, err = buf.WriteString(fmt.Sprintf("%d", j))
145-
if err != nil {
144+
145+
if _, err := buf.WriteString(fmt.Sprintf("%v%d", prefix, j)); err != nil {
146146
return "", err
147147
}
148148
}
149149
}
150-
_, err := buf.WriteString(sql[start:])
151-
if err != nil {
150+
151+
if _, err := buf.WriteString(sql[start:]); err != nil {
152152
return "", err
153153
}
154+
154155
return buf.String(), nil
155156
}

sql_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package builder
66

77
import (
8+
sql2 "database/sql"
89
"fmt"
910
"io/ioutil"
1011
"os"
@@ -35,6 +36,10 @@ func TestBoundSQLConverter(t *testing.T) {
3536
assert.NoError(t, err)
3637
assert.EqualValues(t, placeholderBoundSQL, newSQL)
3738

39+
newSQL, err = ConvertToBoundSQL(placeholderConverterSQL, []interface{}{1, 2.1, sql2.Named("any", "3"), uint(4), "5", true})
40+
assert.NoError(t, err)
41+
assert.EqualValues(t, placeholderBoundSQL, newSQL)
42+
3843
newSQL, err = ConvertToBoundSQL(placeholderConverterSQL, []interface{}{1, 2.1, "3", 4, "5"})
3944
assert.Error(t, err)
4045
assert.EqualValues(t, ErrNeedMoreArguments, err)
@@ -162,3 +167,42 @@ func TestExecutableCheck(t *testing.T) {
162167
err = f.executableCheck("SELECT * FROM table3")
163168
assert.Error(t, err)
164169
}
170+
171+
func TestToSQLInDifferentDialects(t *testing.T) {
172+
sql, args, err := Postgres().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"})).ToSQL()
173+
assert.NoError(t, err)
174+
assert.EqualValues(t, "SELECT * FROM table1 WHERE a=$1 AND b<>$2", sql)
175+
assert.EqualValues(t, []interface{}{"1", "100"}, args)
176+
177+
sql, args, err = MySQL().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"})).ToSQL()
178+
assert.NoError(t, err)
179+
assert.EqualValues(t, "SELECT * FROM table1 WHERE a=? AND b<>?", sql)
180+
assert.EqualValues(t, []interface{}{"1", "100"}, args)
181+
182+
sql, args, err = MsSQL().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"})).ToSQL()
183+
assert.NoError(t, err)
184+
assert.EqualValues(t, "SELECT * FROM table1 WHERE a=@p1 AND b<>@p2", sql)
185+
assert.EqualValues(t, []interface{}{sql2.Named("p1", "1"), sql2.Named("p2", "100")}, args)
186+
187+
// test sql.NamedArg in cond
188+
sql, args, err = MsSQL().Select().From("table1").Where(Eq{"a": sql2.NamedArg{Name: "param", Value: "1"}}.And(Neq{"b": "100"})).ToSQL()
189+
assert.NoError(t, err)
190+
assert.EqualValues(t, "SELECT * FROM table1 WHERE a=@p1 AND b<>@p2", sql)
191+
assert.EqualValues(t, []interface{}{sql2.Named("p1", "1"), sql2.Named("p2", "100")}, args)
192+
193+
sql, args, err = Oracle().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"})).ToSQL()
194+
assert.NoError(t, err)
195+
assert.EqualValues(t, "SELECT * FROM table1 WHERE a=:p1 AND b<>:p2", sql)
196+
assert.EqualValues(t, []interface{}{sql2.Named("p1", "1"), sql2.Named("p2", "100")}, args)
197+
198+
// test sql.NamedArg in cond
199+
sql, args, err = Oracle().Select().From("table1").Where(Eq{"a": sql2.Named("a", "1")}.And(Neq{"b": "100"})).ToSQL()
200+
assert.NoError(t, err)
201+
assert.EqualValues(t, "SELECT * FROM table1 WHERE a=:p1 AND b<>:p2", sql)
202+
assert.EqualValues(t, []interface{}{sql2.Named("p1", "1"), sql2.Named("p2", "100")}, args)
203+
204+
sql, args, err = SQLite().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"})).ToSQL()
205+
assert.NoError(t, err)
206+
assert.EqualValues(t, "SELECT * FROM table1 WHERE a=? AND b<>?", sql)
207+
assert.EqualValues(t, []interface{}{"1", "100"}, args)
208+
}

0 commit comments

Comments
 (0)