Commit 9e88aa34 authored by Amos Wenger's avatar Amos Wenger

Switch back to crawshaw.io/sqlite/sqliteutil

parent d68cb367
Pipeline #11365 passed with stage
in 1 minute and 15 seconds
......@@ -7,7 +7,7 @@ import (
"time"
"crawshaw.io/sqlite"
"github.com/itchio/hades/sqliteutil2"
"crawshaw.io/sqlite/sqliteutil"
"github.com/pkg/errors"
)
......@@ -43,7 +43,7 @@ func (c *Context) syncTable(conn *sqlite.Conn, stats *AutoMigrateStats, ms *Mode
}
// migrate table in transaction
defer sqliteutil2.Save(conn)(&err)
defer sqliteutil.Save(conn)(&err)
err = c.ExecRaw(conn, "PRAGMA foreign_keys = 0", nil)
if err != nil {
......
......@@ -6,9 +6,8 @@ import (
"testing"
"time"
"github.com/itchio/hades/sqliteutil2"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/go-xorm/builder"
"github.com/itchio/hades"
"github.com/stretchr/testify/assert"
......@@ -284,7 +283,7 @@ func Test_AutoMigratePreservesData(t *testing.T) {
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
defer sqliteutil2.Exec(conn, "DROP TABLE androids", nil)
defer sqliteutil.Exec(conn, "DROP TABLE androids", nil)
{
type AndroidTraits struct {
......
......@@ -4,8 +4,8 @@ import (
"time"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/go-xorm/builder"
"github.com/itchio/hades/sqliteutil2"
"github.com/pkg/errors"
)
......@@ -36,7 +36,7 @@ func (c *Context) ExecRaw(conn *sqlite.Conn, query string, resultFn ResultFn, ar
startTime = time.Now()
}
err := sqliteutil2.Exec(conn, query, resultFn, args...)
err := sqliteutil.Exec(conn, query, resultFn, args...)
if c.Log {
c.Consumer.Debugf("[%s] %s %+v", time.Since(startTime), query, args)
......
......@@ -5,9 +5,8 @@ import (
"reflect"
"strings"
"github.com/itchio/hades/sqliteutil2"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/pkg/errors"
)
......@@ -15,11 +14,8 @@ type AllEntities map[reflect.Type]EntityMap
type EntityMap []interface{}
func (c *Context) Save(conn *sqlite.Conn, rec interface{}, opts ...SaveParam) (err error) {
defer sqliteutil2.Save(conn)(&err)
return c.SaveNoTransaction(conn, rec, opts...)
}
defer sqliteutil.Save(conn)(&err)
func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...SaveParam) error {
var params saveParams
for _, o := range opts {
o.ApplyToSaveParams(&params)
......
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package sqliteutil2
import (
"fmt"
"reflect"
"strings"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
// Exec executes an SQLite query.
//
// For each result row, the resultFn is called.
// Result values can be read by resultFn using stmt.Column* methods.
// If resultFn returns an error then iteration ceases and Exec returns
// the error value.
//
// Any args provided to Exec are bound to numbered parameters of the
// query using the Stmt Bind* methods. Basic reflection on args is used
// to map:
//
// integers to BindInt64
// floats to BindFloat
// []byte to BindBytes
// string to BindText
//
// All other kinds are printed using fmt.Sprintf("%v", v) and passed
// to BindText.
//
// Exec is implemented using the Stmt prepare mechanism which allows
// better interactions with Go's type system and avoids pitfalls of
// passing a Go closure to cgo.
//
// As Exec is implemented using Conn.Prepare, subsequent calls to Exec
// with the same statement will reuse the cached statement object.
//
// Typical use:
//
// conn := dbpool.Get()
// defer dbpool.Put(conn)
//
// if err := sqliteutil2.Exec(conn, "INSERT INTO t (a, b, c, d) VALUES (?, ?, ?, ?);", nil, "a1", 1, 42, 1); err != nil {
// // handle err
// }
//
// var a []string
// var b []int64
// fn := func(stmt *sqlite.Stmt) error {
// a = append(a, stmt.ColumnText(0))
// b = append(b, stmt.ColumnInt64(1))
// return nil
// }
// err := sqlutil.Exec(conn, "SELECT a, b FROM t WHERE c = ? AND d = ?;", fn, 42, 1)
// if err != nil {
// // handle err
// }
func Exec(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) error {
stmt, err := conn.Prepare(query)
if err != nil {
return annotateErr(err)
}
return exec(stmt, resultFn, args)
}
// ExecTransient executes an SQLite query without caching the
// underlying query.
// The interface is exactly the same as Exec.
//
// It is the spiritual equivalent of sqlite3_exec.
func ExecTransient(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) (err error) {
var stmt *sqlite.Stmt
var trailingBytes int
stmt, trailingBytes, err = conn.PrepareTransient(query)
if err != nil {
return annotateErr(err)
}
defer func() {
ferr := stmt.Finalize()
if err == nil {
err = ferr
}
}()
if trailingBytes != 0 {
return errors.Errorf("sqliteutil2.Exec: query %q has trailing bytes", query)
}
return exec(stmt, resultFn, args)
}
func exec(stmt *sqlite.Stmt, resultFn func(stmt *sqlite.Stmt) error, args []interface{}) error {
for i, arg := range args {
i++ // parameters are 1-indexed
v := reflect.ValueOf(arg)
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
stmt.BindInt64(i, v.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
stmt.BindInt64(i, int64(v.Uint()))
case reflect.Float32, reflect.Float64:
stmt.BindFloat(i, v.Float())
case reflect.String:
stmt.BindText(i, v.String())
case reflect.Invalid:
stmt.BindNull(i)
case reflect.Bool:
stmt.BindBool(i, v.Bool())
default:
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
stmt.BindBytes(i, v.Bytes())
} else {
stmt.BindText(i, fmt.Sprintf("%v", arg))
}
}
}
for {
hasRow, err := stmt.Step()
if err != nil {
return annotateErr(err)
}
if !hasRow {
break
}
if resultFn != nil {
if err := resultFn(stmt); err != nil {
if err, isError := err.(sqlite.Error); isError {
if err.Loc == "" {
err.Loc = "Exec"
} else {
err.Loc = "Exec: " + err.Loc
}
}
// don't modify non-Error errors from resultFn.
return err
}
}
}
return nil
}
func annotateErr(err error) error {
if err, isError := err.(sqlite.Error); isError {
if err.Loc == "" {
err.Loc = "Exec"
} else {
err.Loc = "Exec: " + err.Loc
}
return errors.WithStack(err)
}
return errors.Errorf("sqlutil.Exec: %v", err)
}
// ExecScript executes a script of SQL statements.
//
// The script is wrapped in a SAVEPOINT transaction,
// which is rolled back on any error.
func ExecScript(conn *sqlite.Conn, queries string) (err error) {
defer Save(conn)(&err)
for {
queries = strings.TrimSpace(queries)
if queries == "" {
break
}
var stmt *sqlite.Stmt
var trailingBytes int
stmt, trailingBytes, err = conn.PrepareTransient(queries)
if err != nil {
return err
}
usedBytes := len(queries) - trailingBytes
queries = queries[usedBytes:]
_, err := stmt.Step()
stmt.Finalize()
if err != nil {
return err
}
}
return nil
}
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package sqliteutil2_test
import (
"reflect"
"testing"
"crawshaw.io/sqlite"
"github.com/itchio/hades/sqliteutil2"
"github.com/pkg/errors"
)
func TestExec(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
if err := sqliteutil2.ExecTransient(conn, "CREATE TABLE t (a TEXT, b INTEGER);", nil); err != nil {
t.Fatal(err)
}
if err := sqliteutil2.Exec(conn, "INSERT INTO t (a, b) VALUES (?, ?);", nil, "a1", 1); err != nil {
t.Error(err)
}
if err := sqliteutil2.Exec(conn, "INSERT INTO t (a, b) VALUES (?, ?);", nil, "a2", 2); err != nil {
t.Error(err)
}
var a []string
var b []int64
fn := func(stmt *sqlite.Stmt) error {
a = append(a, stmt.ColumnText(0))
b = append(b, stmt.ColumnInt64(1))
return nil
}
if err := sqliteutil2.ExecTransient(conn, "SELECT a, b FROM t;", fn); err != nil {
t.Fatal(err)
}
if want := []string{"a1", "a2"}; !reflect.DeepEqual(a, want) {
t.Errorf("a=%v, want %v", a, want)
}
if want := []int64{1, 2}; !reflect.DeepEqual(b, want) {
t.Errorf("b=%v, want %v", b, want)
}
}
func TestExecErr(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
err = sqliteutil2.Exec(conn, "INVALID SQL STMT", nil)
if err == nil {
t.Error("invalid SQL did not return an error code")
}
if got, want := sqlite.ErrCode(err), sqlite.SQLITE_ERROR; got != want {
t.Errorf("INVALID err code=%s, want %s", got, want)
}
if err := sqliteutil2.Exec(conn, "CREATE TABLE t (c1, c2);", nil); err != nil {
t.Error(err)
}
if err := sqliteutil2.Exec(conn, "INSERT INTO t (c1, c2) VALUES (?, ?);", nil, 1, 1); err != nil {
t.Error(err)
}
if err := sqliteutil2.Exec(conn, "INSERT INTO t (c1, c2) VALUES (?, ?);", nil, 2, 2); err != nil {
t.Error(err)
}
err = sqliteutil2.Exec(conn, "INSERT INTO t (c1, c2) VALUES (?, ?);", nil, 1, 1, 1)
if got, want := sqlite.ErrCode(errors.Cause(err)), sqlite.SQLITE_RANGE; got != want {
t.Errorf("INSERT err code=%s, want %s", got, want)
}
calls := 0
customErr := errors.Errorf("custom err")
fn := func(stmt *sqlite.Stmt) error {
calls++
return customErr
}
err = sqliteutil2.Exec(conn, "SELECT c1 FROM t;", fn)
if err != customErr {
t.Errorf("SELECT want err=customErr, got: %v", err)
}
if calls != 1 {
t.Errorf("SELECT want truncated callback calls, got calls=%d", calls)
}
}
func TestExecScript(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
script := `
CREATE TABLE t (a TEXT, b INTEGER);
INSERT INTO t (a, b) VALUES ("a1", 1);
INSERT INTO t (a, b) VALUES ("a2", 2);
`
if err := sqliteutil2.ExecScript(conn, script); err != nil {
t.Error(err)
}
sum := 0
fn := func(stmt *sqlite.Stmt) error {
sum = stmt.ColumnInt(0)
return nil
}
if err := sqliteutil2.Exec(conn, "SELECT sum(b) FROM t;", fn); err != nil {
t.Fatal(err)
}
if sum != 3 {
t.Errorf("sum=%d, want 3", sum)
}
}
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package sqliteutil2
import (
"fmt"
"runtime"
"strings"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
// Save creates a named SQLite transaction using SAVEPOINT.
//
// On success Savepoint returns a releaseFn that will call either
// RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error. This is designed to be deferred.
//
// Example:
//
// func doWork(conn *sqlite.Conn) (err error) {
// defer sqliteutil2.Save(conn)(&err)
//
// // ... do work in the transaction
// }
//
// https://www.sqlite.org/lang_savepoint.html
func Save(conn *sqlite.Conn) (releaseFn func(*error)) {
name := "sqliteutil2.Save" // safe as names can be reused
var pc [3]uintptr
if n := runtime.Callers(0, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
if _, more := frames.Next(); more { // runtime.Callers
if _, more := frames.Next(); more { // savepoint.Save
frame, _ := frames.Next() // caller we care about
if frame.Function != "" {
name = frame.Function
}
}
}
}
releaseFn, err := savepoint(conn, name)
if err != nil {
if sqlite.ErrCode(errors.Cause(err)) == sqlite.SQLITE_INTERRUPT {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
panic(err)
}
return releaseFn
}
func savepoint(conn *sqlite.Conn, name string) (releaseFn func(*error), err error) {
if strings.Contains(name, `"`) {
return nil, errors.Errorf("sqliteutil2.Savepoint: invalid name: %q", name)
}
if err := Exec(conn, fmt.Sprintf("SAVEPOINT %q;", name), nil); err != nil {
return nil, err
}
releaseFn = func(errp *error) {
if p := recover(); p != nil {
Exec(conn, fmt.Sprintf("ROLLBACK TO %q;", name), nil)
panic(p)
}
if *errp == nil {
*errp = Exec(conn, fmt.Sprintf("RELEASE %q;", name), nil)
} else {
err := Exec(conn, fmt.Sprintf("ROLLBACK TO %q;", name), nil)
if err != nil {
panic(err)
}
err = Exec(conn, fmt.Sprintf("RELEASE %q;", name), nil)
if err != nil {
panic(err)
}
}
}
return releaseFn, nil
}
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package sqliteutil2
import (
"strings"
"testing"
"github.com/pkg/errors"
"crawshaw.io/sqlite"
)
func TestExec(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
if err := Exec(conn, "CREATE TABLE t (c1);", nil); err != nil {
t.Fatal(err)
}
countFn := func() int {
var count int
fn := func(stmt *sqlite.Stmt) error {
count = stmt.ColumnInt(0)
return nil
}
if err := Exec(conn, "SELECT count(*) FROM t;", fn); err != nil {
t.Fatal(err)
}
return count
}
errNoSuccess := errors.New("succeed=false")
insert := func(succeed bool) (err error) {
defer Save(conn)(&err)
if err := Exec(conn, `INSERT INTO t VALUES ("hello");`, nil); err != nil {
t.Fatal(err)
}
if succeed {
return nil
}
return errNoSuccess
}
if err := insert(true); err != nil {
t.Fatal(err)
}
if got := countFn(); got != 1 {
t.Errorf("expecting 1 row, got %d", got)
}
if err := insert(true); err != nil {
t.Fatal(err)
}
if got := countFn(); got != 2 {
t.Errorf("expecting 2 rows, got %d", got)
}
if err := insert(false); err != errNoSuccess {
t.Errorf("expecting insert to fail with errNoSuccess, got %v", err)
}
if got := countFn(); got != 2 {
t.Errorf("expecting 2 rows, got %d", got)
}
}
func TestPanic(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
if err := Exec(conn, "CREATE TABLE t (c1);", nil); err != nil {
t.Fatal(err)
}
if err := Exec(conn, `INSERT INTO t VALUES ("one");`, nil); err != nil {
t.Fatal(err)
}
defer func() {
p := recover()
if p == nil {
t.Errorf("panic expected")
}
if err, isErr := p.(error); !isErr || !strings.Contains(err.Error(), "sqlite") {
t.Errorf("panic is not an sqlite error: %v", err)
}
count := 0
fn := func(stmt *sqlite.Stmt) error {
count = stmt.ColumnInt(0)
return nil
}
if err := Exec(conn, "SELECT count(*) FROM t;", fn); err != nil {
t.Error(err)
}
if count != 1 {
t.Errorf("got %d rows, want 1", count)
}
}()
if err := doPanic(conn); err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func doPanic(conn *sqlite.Conn) (err error) {
defer Save(conn)(&err)
if err := Exec(conn, `INSERT INTO t VALUES ("hello");`, nil); err != nil {
return err
}
conn.Prep("SELECT bad query") // panics
return nil
}
func TestDone(t *testing.T) {
doneCh := make(chan struct{})
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
conn.SetInterrupt(doneCh)
close(doneCh)
relFn := Save(conn)
relFn(&err)
if code := sqlite.ErrCode(errors.Cause(err)); code != sqlite.SQLITE_INTERRUPT {
t.Errorf("savepoint release function error code is %v, want SQLITE_INTERRUPT", code)
}
}
func TestReleaseTx(t *testing.T) {
conn1, err := sqlite.OpenConn("file::memory:?mode=memory&cache=shared", 0)
if err != nil {
t.Fatal(err)
}
defer conn1.Close()
conn2, err := sqlite.OpenConn("file::memory:?mode=memory&cache=shared", 0)
if err != nil {
t.Fatal(err)
}
defer conn2.Close()
Exec(conn1, "DROP TABLE t;", nil)
if err := Exec(conn1, "CREATE TABLE t (c1);", nil); err != nil {
t.Fatal(err)
}
countFn := func() int {
var count int
fn := func(stmt *sqlite.Stmt) error {
count = stmt.ColumnInt(0)
return nil
}
if err := Exec(conn2, "SELECT count(*) FROM t;", fn); err != nil {
t.Fatal(err)
}
return count
}
errNoSuccess := errors.New("succeed=false")
insert := func(succeed bool) (err error) {
defer Save(conn1)(&err)
if err := Exec(conn1, `INSERT INTO t VALUES ("hello");`, nil); err != nil {
t.Fatal(err)
}
if succeed {
return nil
}
return errNoSuccess
}
if err := insert(true); err != nil {
t.Fatal(err)
}
if got := countFn(); got != 1 {
t.Errorf("expecting 1 row, got %d", got)
}
if err := insert(false); err == nil {
t.Fatal(err)
}
// If the transaction is still open, countFn will get stuck
// on conn2 waiting for conn1's write lock to release.
if got := countFn(); got != 1 {