Commit 26db24c5 authored by Amos Wenger's avatar Amos Wenger

Use Upsert for saveRows

parent df86ec3e
Pipeline #11352 failed with stage
in 56 seconds
......@@ -5,19 +5,10 @@ import (
)
type Context struct {
ScopeMap *ScopeMap
Consumer *state.Consumer
Stats Stats
Error error
Log bool
QueryCount int64
}
type Stats struct {
Inserts int64
Updates int64
Deletes int64
Current int64
ScopeMap *ScopeMap
Consumer *state.Consumer
Error error
Log bool
}
func NewContext(consumer *state.Consumer, models ...interface{}) (*Context, error) {
......
......@@ -31,8 +31,6 @@ func (c *Context) ExecWithSearch(conn *sqlite.Conn, b *builder.Builder, search S
}
func (c *Context) ExecRaw(conn *sqlite.Conn, query string, resultFn ResultFn, args ...interface{}) error {
c.QueryCount++
var startTime time.Time
if c.Log {
startTime = time.Now()
......
......@@ -251,25 +251,7 @@ func Test_ManyToManyThorough(t *testing.T) {
}
originalAuthors := p.Authors
{
beforeSaveQueryCount := c.QueryCount
ordie(c.Save(conn, p, hades.Assoc("Authors")))
pieceSelect := 1
pieceInsert := 1
authorSelect := 1
authorInsert := len(p.Authors)
pieceAuthorSelect := 1
pieceAuthorInsert := len(p.Authors)
total := pieceSelect + pieceInsert +
authorSelect + authorInsert +
pieceAuthorSelect + pieceAuthorInsert
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
ordie(c.Save(conn, p, hades.Assoc("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(p.Authors))
......@@ -285,23 +267,7 @@ func Test_ManyToManyThorough(t *testing.T) {
}
p.Authors = fewerAuthors
{
beforeSaveQueryCount := c.QueryCount
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
authorSelect := 1
pieceAuthorSelect := 1
pieceAuthorDelete := 1
total := pieceSelect +
authorSelect +
pieceAuthorSelect + pieceAuthorDelete
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors))
......@@ -311,23 +277,7 @@ func Test_ManyToManyThorough(t *testing.T) {
p.Authors[2].Name = "Hansel"
{
beforeSaveQueryCount := c.QueryCount
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
authorSelect := 1
authorUpdate := 1
pieceAuthorSelect := 1
total := pieceSelect +
authorSelect + authorUpdate +
pieceAuthorSelect
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors))
......@@ -342,26 +292,7 @@ func Test_ManyToManyThorough(t *testing.T) {
Name: "Joseph",
})
{
beforeSaveQueryCount := c.QueryCount
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
authorSelect := 1
authorInsert := 1
authorUpdate := 2
pieceAuthorSelect := 1
pieceAuthorInsert := 1
pieceAuthorDelete := 1
total := pieceSelect +
authorSelect + authorInsert + authorUpdate +
pieceAuthorSelect + pieceAuthorInsert + pieceAuthorDelete
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1)
......
......@@ -41,6 +41,7 @@ func Test_Null(t *testing.T) {
if err != nil {
panic(err)
}
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
......
......@@ -4,8 +4,6 @@ import (
"math"
"reflect"
"github.com/go-xorm/builder"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
......@@ -115,90 +113,11 @@ func (c *Context) saveRows(conn *sqlite.Conn, mode AssocMode, inputIface interfa
return nil
}
primaryField := primaryFields[0]
// record should be a *SomeModel, we're effectively doing (*record).<pkColumn>
getKey := func(record reflect.Value) interface{} {
f := record.Elem().FieldByName(primaryField.Name)
if !f.IsValid() {
return nil
}
return f.Interface()
}
// collect primary key values for all of input
var keys []interface{}
for i := 0; i < fresh.Len(); i++ {
record := fresh.Index(i)
keys = append(keys, getKey(record))
}
cacheAddr, err := c.fetchPagedByPK(conn, primaryField.DBName, keys, fresh.Type(), Search{})
if err != nil {
return errors.WithMessage(err, "getting existing rows")
}
cache := cacheAddr.Elem()
// index cached items by their primary key
// so we can look them up in O(1) when comparing
cacheByPK := make(map[interface{}]reflect.Value)
for i := 0; i < cache.Len(); i++ {
record := cache.Index(i)
cacheByPK[getKey(record)] = record
}
// compare cached records with fresh records
var inserts []reflect.Value
var updates = make(map[interface{}]ChangedFields)
doneKeys := make(map[interface{}]bool)
for i := 0; i < fresh.Len(); i++ {
frec := fresh.Index(i)
key := getKey(frec)
if _, ok := doneKeys[key]; ok {
continue
}
doneKeys[key] = true
if crec, ok := cacheByPK[key]; ok {
// frec and crec are *SomeModel, but `RecordEqual` ignores pointer
// equality - we want to compare the contents of the struct
// so we indirect to SomeModel here.
ifrec := frec.Elem().Interface()
icrec := crec.Elem().Interface()
cf, err := DiffRecord(ifrec, icrec, scope)
if err != nil {
return errors.WithMessage(err, "diffing db records")
}
if cf != nil {
updates[key] = cf
}
} else {
inserts = append(inserts, frec)
}
}
c.Stats.Inserts += int64(len(inserts))
c.Stats.Updates += int64(len(updates))
c.Stats.Current += int64(fresh.Len() - len(updates) - len(inserts))
if len(inserts) > 0 {
for _, rec := range inserts {
err := c.Insert(conn, scope, rec)
if err != nil {
return errors.WithMessage(err, "inserting new DB records")
}
}
}
for key, cf := range updates {
eq := cf.ToEq()
err := c.Exec(conn, builder.Update(eq).Into(scope.TableName()).Where(builder.Eq{primaryField.DBName: key}), nil)
rec := fresh.Index(i)
err := c.Upsert(conn, scope, rec)
if err != nil {
return errors.WithMessage(err, "updating DB records")
return errors.WithMessage(err, "upserting DB records")
}
}
......
......@@ -38,6 +38,7 @@ func Test_Select(t *testing.T) {
if err != nil {
panic(err)
}
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
......@@ -135,6 +136,7 @@ func Test_SelectSquashed(t *testing.T) {
if err != nil {
panic(err)
}
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
......
package hades
import (
"fmt"
"reflect"
"strings"
"crawshaw.io/sqlite"
"github.com/go-xorm/builder"
)
// TODO: cache me
func (scope *Scope) ToSets() []string {
var sets []string
var processField func(sf *StructField)
processField = func(sf *StructField) {
if sf.IsSquashed {
for _, nsf := range sf.SquashedFields {
processField(nsf)
}
}
if !sf.IsNormal {
return
}
if sf.IsPrimaryKey {
return
}
name := EscapeIdentifier(sf.DBName)
sets = append(sets, fmt.Sprintf("%s=excluded.%s", name, name))
}
for _, sf := range scope.GetStructFields() {
processField(sf)
}
return sets
}
func (c *Context) Upsert(conn *sqlite.Conn, scope *Scope, rec reflect.Value) error {
eq := scope.ToEq(rec)
b := builder.Insert(eq).Into(scope.TableName())
sql, args, err := b.ToSQL()
if err != nil {
return err
}
sets := scope.ToSets()
if len(sets) == 0 {
sql = fmt.Sprintf("%s ON CONFLICT DO NOTHING",
sql,
)
} else {
var pfNames []string
for _, pf := range scope.GetModelStruct().PrimaryFields {
pfNames = append(pfNames, pf.DBName)
}
sql = fmt.Sprintf("%s ON CONFLICT(%s) DO UPDATE SET %s",
sql,
strings.Join(pfNames, ","),
strings.Join(sets, ","),
)
}
return c.ExecRaw(conn, sql, nil, args...)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment