Commit 12ceaf03 authored by Amos Wenger's avatar Amos Wenger

Fork sqliteutil to support booleans + support squash in select

parent 114fc201
Pipeline #10983 passed with stage
in 33 seconds
......@@ -7,7 +7,7 @@ import (
"time"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/itchio/hades/sqliteutil2"
"github.com/pkg/errors"
)
......@@ -32,7 +32,7 @@ func (c *Context) syncTable(conn *sqlite.Conn, ms *ModelStruct) (err error) {
}
// migrate table in transaction
defer sqliteutil.Save(conn)(&err)
defer sqliteutil2.Save(conn)(&err)
err = c.ExecRaw(conn, "PRAGMA foreign_keys = 0", nil)
if err != nil {
......
......@@ -224,6 +224,7 @@ func Test_AutoMigrateSquash(t *testing.T) {
c.Log = true
ordie(c.AutoMigrate(conn))
defer c.ExecRaw(conn, "DROP TABLE androids", nil)
pti, err := c.PragmaTableInfo(conn, "androids")
ordie(err)
......
......@@ -4,8 +4,8 @@ import (
"time"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/go-xorm/builder"
"github.com/itchio/hades/sqliteutil2"
"github.com/pkg/errors"
)
......@@ -27,7 +27,7 @@ func (c *Context) ExecRaw(conn *sqlite.Conn, query string, resultFn ResultFn, ar
startTime = time.Now()
}
err := sqliteutil.Exec(conn, query, resultFn, args...)
err := sqliteutil2.Exec(conn, query, resultFn, args...)
if c.Log {
c.Consumer.Debugf("[%s] %s %+v", time.Since(startTime), query, args)
......
......@@ -5,9 +5,9 @@ import (
"reflect"
"github.com/go-xorm/builder"
"github.com/itchio/hades/sqliteutil2"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/pkg/errors"
)
......@@ -32,7 +32,7 @@ func (c *Context) SaveOne(conn *sqlite.Conn, record interface{}) (err error) {
}
func (c *Context) Save(conn *sqlite.Conn, params *SaveParams) (err error) {
defer sqliteutil.Save(conn)(&err)
defer sqliteutil2.Save(conn)(&err)
return c.SaveNoTransaction(conn, params)
}
......
......@@ -9,8 +9,20 @@ import (
)
func (c *Context) Scan(stmt *sqlite.Stmt, structFields []*StructField, result reflect.Value) error {
for i, sf := range structFields {
i := 0
var processField func(sf *StructField, result reflect.Value) error
processField = func(sf *StructField, result reflect.Value) error {
field := result.FieldByName(sf.Name)
if sf.IsSquashed {
for _, nsf := range sf.SquashedFields {
err := processField(nsf, field)
if err != nil {
return err
}
}
return nil
}
fieldEl := field
typ := field.Type()
......@@ -22,7 +34,8 @@ func (c *Context) Scan(stmt *sqlite.Stmt, structFields []*StructField, result re
wasPtr = true
if colTyp == sqlite.SQLITE_NULL {
field.Set(reflect.Zero(field.Type()))
continue
i++
return nil
}
fieldEl = field.Elem()
......@@ -77,6 +90,17 @@ func (c *Context) Scan(stmt *sqlite.Stmt, structFields []*StructField, result re
default:
return errors.Errorf("For model %s, unknown kind %s for field %s", result.Type(), field.Type().Kind(), sf.Name)
}
i++
return nil
}
for _, sf := range structFields {
err := processField(sf, result)
if err != nil {
return err
}
}
return nil
}
......@@ -82,12 +82,27 @@ func (c *Context) SelectOne(conn *sqlite.Conn, result interface{}, cond builder.
func (c *Context) selectFields(ms *ModelStruct) ([]string, []*StructField) {
var columns []string
var fields []*StructField
for _, sf := range ms.StructFields {
var processField func(sf *StructField, nested bool)
processField = func(sf *StructField, nested bool) {
if sf.IsSquashed {
fields = append(fields, sf)
for _, nsf := range sf.SquashedFields {
processField(nsf, true)
}
}
if !sf.IsNormal {
continue
return
}
columns = append(columns, fmt.Sprintf(`%s.%s`, EscapeIdentifier(ms.TableName), EscapeIdentifier(sf.DBName)))
fields = append(fields, sf)
if !nested {
fields = append(fields, sf)
}
}
for _, sf := range ms.StructFields {
processField(sf, false)
}
return columns, fields
......
......@@ -97,3 +97,62 @@ func Test_Select(t *testing.T) {
err = c.SelectOne(conn, nam, builder.Eq{"id": 3})
assert.Error(t, err, "SelectOne must reject pointer to non-struct")
}
func Test_SelectSquashed(t *testing.T) {
consumer := &state.Consumer{
OnMessage: func(lvl string, message string) {
t.Logf("[%s] %s", lvl, message)
},
}
type AndroidTraits struct {
Wise bool
Funny bool
}
type Android struct {
ID int64
Traits AndroidTraits `hades:"squash"`
}
c, err := hades.NewContext(consumer, &Android{})
if err != nil {
panic(err)
}
c.Log = true
sqlite.Logger = func(code sqlite.ErrorCode, msg []byte) {
t.Logf("[SQLITE] %d %s", code, string(msg))
}
dbpool, err := sqlite.Open("file:memory:?mode=memory", 0, 10)
if err != nil {
panic(err)
}
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
wtest.Must(t, c.ExecRaw(conn, "CREATE TABLE androids (id INTEGER PRIMARY KEY, wise BOOLEAN, funny BOOLEAN)", nil))
defer c.ExecRaw(conn, "DROP TABLE androids", nil)
baseAndroids := []Android{
Android{ID: 1, Traits: AndroidTraits{Wise: true}},
Android{ID: 2, Traits: AndroidTraits{Funny: true}},
Android{ID: 3},
Android{ID: 4, Traits: AndroidTraits{Wise: true, Funny: true}},
}
for _, a := range baseAndroids {
wtest.Must(t, c.Exec(conn, builder.Insert(builder.Eq{"id": a.ID, "wise": a.Traits.Wise, "funny": a.Traits.Funny}).Into("androids"), nil))
}
count, err := c.Count(conn, &Android{}, builder.NewCond())
wtest.Must(t, err)
assert.EqualValues(t, 4, count)
a := &Android{}
err = c.SelectOne(conn, a, builder.Eq{"id": 1})
wtest.Must(t, err)
assert.EqualValues(t, baseAndroids[0], *a)
}
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package sqliteutil2
import (
"fmt"
"reflect"
"strings"
"crawshaw.io/sqlite"
)
// Exec executes an SQLite query.
//
// For each result row, the resultFn is called.
// Result values can be read by resultFn using stmt.Column* methods.
// If resultFn returns an error then iteration ceases and Exec returns
// the error value.
//
// Any args provided to Exec are bound to numbered parameters of the
// query using the Stmt Bind* methods. Basic reflection on args is used
// to map:
//
// integers to BindInt64
// floats to BindFloat
// []byte to BindBytes
// string to BindText
//
// All other kinds are printed using fmt.Sprintf("%v", v) and passed
// to BindText.
//
// Exec is implemented using the Stmt prepare mechanism which allows
// better interactions with Go's type system and avoids pitfalls of
// passing a Go closure to cgo.
//
// As Exec is implemented using Conn.Prepare, subsequent calls to Exec
// with the same statement will reuse the cached statement object.
//
// Typical use:
//
// conn := dbpool.Get()
// defer dbpool.Put(conn)
//
// if err := sqliteutil2.Exec(conn, "INSERT INTO t (a, b, c, d) VALUES (?, ?, ?, ?);", nil, "a1", 1, 42, 1); err != nil {
// // handle err
// }
//
// var a []string
// var b []int64
// fn := func(stmt *sqlite.Stmt) error {
// a = append(a, stmt.ColumnText(0))
// b = append(b, stmt.ColumnInt64(1))
// return nil
// }
// err := sqlutil.Exec(conn, "SELECT a, b FROM t WHERE c = ? AND d = ?;", fn, 42, 1)
// if err != nil {
// // handle err
// }
func Exec(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) error {
stmt, err := conn.Prepare(query)
if err != nil {
return annotateErr(err)
}
return exec(stmt, resultFn, args)
}
// ExecTransient executes an SQLite query without caching the
// underlying query.
// The interface is exactly the same as Exec.
//
// It is the spiritual equivalent of sqlite3_exec.
func ExecTransient(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) (err error) {
var stmt *sqlite.Stmt
var trailingBytes int
stmt, trailingBytes, err = conn.PrepareTransient(query)
if err != nil {
return annotateErr(err)
}
defer func() {
ferr := stmt.Finalize()
if err == nil {
err = ferr
}
}()
if trailingBytes != 0 {
return fmt.Errorf("sqliteutil2.Exec: query %q has trailing bytes", query)
}
return exec(stmt, resultFn, args)
}
func exec(stmt *sqlite.Stmt, resultFn func(stmt *sqlite.Stmt) error, args []interface{}) error {
for i, arg := range args {
i++ // parameters are 1-indexed
v := reflect.ValueOf(arg)
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
stmt.BindInt64(i, v.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
stmt.BindInt64(i, int64(v.Uint()))
case reflect.Float32, reflect.Float64:
stmt.BindFloat(i, v.Float())
case reflect.String:
stmt.BindText(i, v.String())
case reflect.Invalid:
stmt.BindNull(i)
case reflect.Bool:
stmt.BindBool(i, v.Bool())
default:
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
stmt.BindBytes(i, v.Bytes())
} else {
stmt.BindText(i, fmt.Sprintf("%v", arg))
}
}
}
for {
hasRow, err := stmt.Step()
if err != nil {
return annotateErr(err)
}
if !hasRow {
break
}
if resultFn != nil {
if err := resultFn(stmt); err != nil {
if err, isError := err.(sqlite.Error); isError {
if err.Loc == "" {
err.Loc = "Exec"
} else {
err.Loc = "Exec: " + err.Loc
}
}
// don't modify non-Error errors from resultFn.
return err
}
}
}
return nil
}
func annotateErr(err error) error {
if err, isError := err.(sqlite.Error); isError {
if err.Loc == "" {
err.Loc = "Exec"
} else {
err.Loc = "Exec: " + err.Loc
}
return err
}
return fmt.Errorf("sqlutil.Exec: %v", err)
}
// ExecScript executes a script of SQL statements.
//
// The script is wrapped in a SAVEPOINT transaction,
// which is rolled back on any error.
func ExecScript(conn *sqlite.Conn, queries string) (err error) {
defer Save(conn)(&err)
for {
queries = strings.TrimSpace(queries)
if queries == "" {
break
}
var stmt *sqlite.Stmt
var trailingBytes int
stmt, trailingBytes, err = conn.PrepareTransient(queries)
if err != nil {
return err
}
usedBytes := len(queries) - trailingBytes
queries = queries[usedBytes:]
_, err := stmt.Step()
stmt.Finalize()
if err != nil {
return err
}
}
return nil
}
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package sqliteutil2_test
import (
"fmt"
"reflect"
"testing"
"crawshaw.io/sqlite"
"github.com/itchio/hades/sqliteutil2"
)
func TestExec(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
if err := sqliteutil2.ExecTransient(conn, "CREATE TABLE t (a TEXT, b INTEGER);", nil); err != nil {
t.Fatal(err)
}
if err := sqliteutil2.Exec(conn, "INSERT INTO t (a, b) VALUES (?, ?);", nil, "a1", 1); err != nil {
t.Error(err)
}
if err := sqliteutil2.Exec(conn, "INSERT INTO t (a, b) VALUES (?, ?);", nil, "a2", 2); err != nil {
t.Error(err)
}
var a []string
var b []int64
fn := func(stmt *sqlite.Stmt) error {
a = append(a, stmt.ColumnText(0))
b = append(b, stmt.ColumnInt64(1))
return nil
}
if err := sqliteutil2.ExecTransient(conn, "SELECT a, b FROM t;", fn); err != nil {
t.Fatal(err)
}
if want := []string{"a1", "a2"}; !reflect.DeepEqual(a, want) {
t.Errorf("a=%v, want %v", a, want)
}
if want := []int64{1, 2}; !reflect.DeepEqual(b, want) {
t.Errorf("b=%v, want %v", b, want)
}
}
func TestExecErr(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
err = sqliteutil2.Exec(conn, "INVALID SQL STMT", nil)
if err == nil {
t.Error("invalid SQL did not return an error code")
}
if got, want := sqlite.ErrCode(err), sqlite.SQLITE_ERROR; got != want {
t.Errorf("INVALID err code=%s, want %s", got, want)
}
if err := sqliteutil2.Exec(conn, "CREATE TABLE t (c1, c2);", nil); err != nil {
t.Error(err)
}
if err := sqliteutil2.Exec(conn, "INSERT INTO t (c1, c2) VALUES (?, ?);", nil, 1, 1); err != nil {
t.Error(err)
}
if err := sqliteutil2.Exec(conn, "INSERT INTO t (c1, c2) VALUES (?, ?);", nil, 2, 2); err != nil {
t.Error(err)
}
err = sqliteutil2.Exec(conn, "INSERT INTO t (c1, c2) VALUES (?, ?);", nil, 1, 1, 1)
if got, want := sqlite.ErrCode(err), sqlite.SQLITE_RANGE; got != want {
t.Errorf("INSERT err code=%s, want %s", got, want)
}
calls := 0
customErr := fmt.Errorf("custom err")
fn := func(stmt *sqlite.Stmt) error {
calls++
return customErr
}
err = sqliteutil2.Exec(conn, "SELECT c1 FROM t;", fn)
if err != customErr {
t.Errorf("SELECT want err=customErr, got: %v", err)
}
if calls != 1 {
t.Errorf("SELECT want truncated callback calls, got calls=%d", calls)
}
}
func TestExecScript(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
script := `
CREATE TABLE t (a TEXT, b INTEGER);
INSERT INTO t (a, b) VALUES ("a1", 1);
INSERT INTO t (a, b) VALUES ("a2", 2);
`
if err := sqliteutil2.ExecScript(conn, script); err != nil {
t.Error(err)
}
sum := 0
fn := func(stmt *sqlite.Stmt) error {
sum = stmt.ColumnInt(0)
return nil
}
if err := sqliteutil2.Exec(conn, "SELECT sum(b) FROM t;", fn); err != nil {
t.Fatal(err)
}
if sum != 3 {
t.Errorf("sum=%d, want 3", sum)
}
}
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package sqliteutil2
import (
"fmt"
"runtime"
"strings"
"crawshaw.io/sqlite"
)
// Save creates a named SQLite transaction using SAVEPOINT.
//
// On success Savepoint returns a releaseFn that will call either
// RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error. This is designed to be deferred.
//
// Example:
//
// func doWork(conn *sqlite.Conn) (err error) {
// defer sqliteutil2.Save(conn)(&err)
//
// // ... do work in the transaction
// }
//
// https://www.sqlite.org/lang_savepoint.html
func Save(conn *sqlite.Conn) (releaseFn func(*error)) {
name := "sqliteutil2.Save" // safe as names can be reused
var pc [3]uintptr
if n := runtime.Callers(0, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
if _, more := frames.Next(); more { // runtime.Callers
if _, more := frames.Next(); more { // savepoint.Save
frame, _ := frames.Next() // caller we care about
if frame.Function != "" {
name = frame.Function
}
}
}
}
releaseFn, err := savepoint(conn, name)
if err != nil {
if sqlite.ErrCode(err) == sqlite.SQLITE_INTERRUPT {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
panic(err)
}
return releaseFn
}
func savepoint(conn *sqlite.Conn, name string) (releaseFn func(*error), err error) {
if strings.Contains(name, `"`) {
return nil, fmt.Errorf("sqliteutil2.Savepoint: invalid name: %q", name)
}
if err := Exec(conn, fmt.Sprintf("SAVEPOINT %q;", name), nil); err != nil {
return nil, err
}
releaseFn = func(errp *error) {
if p := recover(); p != nil {
Exec(conn, fmt.Sprintf("ROLLBACK TO %q;", name), nil)
panic(p)
}
if *errp == nil {
*errp = Exec(conn, fmt.Sprintf("RELEASE %q;", name), nil)
} else {
err := Exec(conn, fmt.Sprintf("ROLLBACK TO %q;", name), nil)
if err != nil {
panic(err)
}
err = Exec(conn, fmt.Sprintf("RELEASE %q;", name), nil)
if err != nil {
panic(err)
}
}
}
return releaseFn, nil
}
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package sqliteutil2
import (
"errors"
"strings"
"testing"
"crawshaw.io/sqlite"
)
func TestExec(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
if err := Exec(conn, "CREATE TABLE t (c1);", nil); err != nil {
t.Fatal(err)
}
countFn := func() int {
var count int
fn := func(stmt *sqlite.Stmt) error {
count = stmt.ColumnInt(0)
return nil
}
if err := Exec(conn, "SELECT count(*) FROM t;", fn); err != nil {
t.Fatal(err)
}
return count
}
errNoSuccess := errors.New("succeed=false")
insert := func(succeed bool) (err error) {
defer Save(conn)(&err)
if err := Exec(conn, `INSERT INTO t VALUES ("hello");`, nil); err != nil {
t.Fatal(err)
}
if succeed {