...
 
Commits (2)
......@@ -5,19 +5,10 @@ import (
)
type Context struct {
ScopeMap *ScopeMap
Consumer *state.Consumer
Stats Stats
Error error
Log bool
QueryCount int64
}
type Stats struct {
Inserts int64
Updates int64
Deletes int64
Current int64
ScopeMap *ScopeMap
Consumer *state.Consumer
Error error
Log bool
}
func NewContext(consumer *state.Consumer, models ...interface{}) (*Context, error) {
......
......@@ -31,8 +31,6 @@ func (c *Context) ExecWithSearch(conn *sqlite.Conn, b *builder.Builder, search S
}
func (c *Context) ExecRaw(conn *sqlite.Conn, query string, resultFn ResultFn, args ...interface{}) error {
c.QueryCount++
var startTime time.Time
if c.Log {
startTime = time.Now()
......
......@@ -251,25 +251,7 @@ func Test_ManyToManyThorough(t *testing.T) {
}
originalAuthors := p.Authors
{
beforeSaveQueryCount := c.QueryCount
ordie(c.Save(conn, p, hades.Assoc("Authors")))
pieceSelect := 1
pieceInsert := 1
authorSelect := 1
authorInsert := len(p.Authors)
pieceAuthorSelect := 1
pieceAuthorInsert := len(p.Authors)
total := pieceSelect + pieceInsert +
authorSelect + authorInsert +
pieceAuthorSelect + pieceAuthorInsert
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
ordie(c.Save(conn, p, hades.Assoc("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(p.Authors))
......@@ -285,23 +267,7 @@ func Test_ManyToManyThorough(t *testing.T) {
}
p.Authors = fewerAuthors
{
beforeSaveQueryCount := c.QueryCount
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
authorSelect := 1
pieceAuthorSelect := 1
pieceAuthorDelete := 1
total := pieceSelect +
authorSelect +
pieceAuthorSelect + pieceAuthorDelete
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors))
......@@ -311,23 +277,7 @@ func Test_ManyToManyThorough(t *testing.T) {
p.Authors[2].Name = "Hansel"
{
beforeSaveQueryCount := c.QueryCount
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
authorSelect := 1
authorUpdate := 1
pieceAuthorSelect := 1
total := pieceSelect +
authorSelect + authorUpdate +
pieceAuthorSelect
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors))
......@@ -342,26 +292,7 @@ func Test_ManyToManyThorough(t *testing.T) {
Name: "Joseph",
})
{
beforeSaveQueryCount := c.QueryCount
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
authorSelect := 1
authorInsert := 1
authorUpdate := 2
pieceAuthorSelect := 1
pieceAuthorInsert := 1
pieceAuthorDelete := 1
total := pieceSelect +
authorSelect + authorInsert + authorUpdate +
pieceAuthorSelect + pieceAuthorInsert + pieceAuthorDelete
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1)
......
......@@ -41,6 +41,7 @@ func Test_Null(t *testing.T) {
if err != nil {
panic(err)
}
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
......
......@@ -45,27 +45,3 @@ func (c *Context) fetchPagedByPK(conn *sqlite.Conn, PKDBName string, keys []inte
return result, nil
}
func (c *Context) deletePagedByPK(conn *sqlite.Conn, TableName string, PKDBName string, keys []interface{}, userCond builder.Cond) error {
remainingItems := keys
for len(remainingItems) > 0 {
var pageSize int
if len(remainingItems) > maxSqlVars {
pageSize = maxSqlVars
} else {
pageSize = len(remainingItems)
}
cond := builder.And(userCond, builder.In(PKDBName, remainingItems[:pageSize]...))
query := builder.Delete(cond).From(TableName)
err := c.Exec(conn, query, nil)
if err != nil {
return err
}
remainingItems = remainingItems[pageSize:]
}
return nil
}
package hades
import (
"fmt"
"reflect"
"strings"
"github.com/go-xorm/builder"
"github.com/itchio/hades/sqliteutil2"
"crawshaw.io/sqlite"
......@@ -153,9 +153,8 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...
case "has_many":
// if we're in replace mode
if vri.Field.Mode() == AssocModeReplace {
// and it's an actually
// has_many, not a disguised
// many_to_many
// and it's an actual has_many,
// not a disguised many_to_many
if len(vri.ModelStruct.PrimaryFields) == 1 {
// then cull now
cull = true
......@@ -175,7 +174,6 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...
}
if cull {
var oldValuePKs []string
rel := vri.Relationship
parentPF := c.NewScope(p.Interface()).PrimaryField()
......@@ -193,49 +191,58 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...
return errors.Errorf("Since %v has_many %v, expected %v to have one primary key. Instead, it has primary fields: %s",
pri.Name(), vri.Name(), vri.Name(), strings.Join(pfNames, ", "))
}
valuePF := c.NewScope(v.Interface()).PrimaryField()
if valuePF == nil {
return errors.Errorf("Can't save %v has_many %v: value has no primary keys", pri.Type, vri.Type)
}
q := builder.Select(rel.AssociationForeignDBNames[0]).
From(vri.ModelStruct.TableName).
Where(builder.Eq{
rel.ForeignDBNames[0]: parentPK,
})
passedPFs := make(map[interface{}]struct{})
for i := 0; i < v.Len(); i++ {
rec := v.Index(i)
pf := rec.Elem().FieldByName(rel.AssociationForeignDBNames[0]).Interface()
passedPFs[pf] = struct{}{}
}
err = c.Exec(conn, q, func(stmt *sqlite.Stmt) error {
pk := stmt.ColumnText(0)
oldValuePKs = append(oldValuePKs, pk)
pfTyp := valuePF.Struct.Type
pfKind := pfTyp.Kind()
selectQuery := fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`,
EscapeIdentifier(rel.AssociationForeignDBNames[0]),
EscapeIdentifier(vri.ModelStruct.TableName),
EscapeIdentifier(rel.ForeignDBNames[0]),
)
deleteQuery := fmt.Sprintf(`DELETE FROM %s WHERE %s = ? AND %s = ?`,
EscapeIdentifier(vri.ModelStruct.TableName),
EscapeIdentifier(rel.ForeignDBNames[0]),
EscapeIdentifier(rel.AssociationForeignDBNames[0]),
)
var removedPFs []interface{}
err := c.ExecRaw(conn, selectQuery, func(stmt *sqlite.Stmt) error {
var pf interface{}
switch pfKind {
case reflect.Int64:
pf = stmt.ColumnInt64(0)
case reflect.String:
pf = stmt.ColumnText(0)
default:
return errors.Errorf("Unsupported primary key for has_many: %v", pfTyp)
}
if _, ok := passedPFs[pf]; !ok {
removedPFs = append(removedPFs, pf)
}
return nil
})
}, parentPK.Interface())
if err != nil {
return err
}
if len(oldValuePKs) > 0 {
var newValuePKs []string
for i := 0; i < v.Len(); i++ {
newValuePKs = append(newValuePKs, c.NewScope(v.Index(i).Interface()).PrimaryField().Field.String())
}
var newValuePKsMap = make(map[string]struct{})
for _, pk := range newValuePKs {
newValuePKsMap[pk] = struct{}{}
}
var vpksToDelete []interface{}
for _, pk := range oldValuePKs {
if _, ok := newValuePKsMap[pk]; !ok {
vpksToDelete = append(vpksToDelete, pk)
}
}
if len(vpksToDelete) > 0 {
err := c.deletePagedByPK(conn, vri.ModelStruct.TableName, valuePF.DBName, vpksToDelete, builder.NewCond())
if err != nil {
return err
}
for _, pf := range removedPFs {
err := c.ExecRaw(conn, deleteQuery, nil, parentPK.Interface(), pf)
if err != nil {
return err
}
}
}
......
package hades
import (
"fmt"
"reflect"
"crawshaw.io/sqlite"
"github.com/go-xorm/builder"
"github.com/pkg/errors"
)
func (c *Context) saveJoins(conn *sqlite.Conn, mode AssocMode, mtm *ManyToMany) error {
joinType := reflect.PtrTo(mtm.Scope.GetModelStruct().ModelType)
getDestinKey := func(v reflect.Value) interface{} {
return v.Elem().FieldByName(mtm.DestinName).Interface()
}
selectQuery := fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`,
EscapeIdentifier(mtm.DestinDBName),
EscapeIdentifier(mtm.JoinTable),
EscapeIdentifier(mtm.SourceDBName),
)
upsertQuery := fmt.Sprintf(`INSERT INTO %s (%s, %s) VALUES (?, ?) ON CONFLICT DO NOTHING`,
EscapeIdentifier(mtm.JoinTable),
EscapeIdentifier(mtm.SourceDBName),
EscapeIdentifier(mtm.DestinDBName),
)
deleteQuery := fmt.Sprintf(`DELETE FROM %s WHERE %s = ? AND %s = ?`,
EscapeIdentifier(mtm.JoinTable),
EscapeIdentifier(mtm.SourceDBName),
EscapeIdentifier(mtm.DestinDBName),
)
deleteAllQuery := fmt.Sprintf(`DELETE FROM %s WHERE %s = ?`,
EscapeIdentifier(mtm.JoinTable),
EscapeIdentifier(mtm.SourceDBName),
)
for sourceKey, joinRecs := range mtm.Values {
cacheAddr := reflect.New(reflect.SliceOf(joinType))
err := c.Select(conn, cacheAddr.Interface(), builder.Eq{mtm.SourceDBName: sourceKey}, Search{})
if err != nil {
return errors.WithMessage(err, "fetching cached records to compare later")
}
cache := cacheAddr.Elem()
cacheByDestinKey := make(map[interface{}]reflect.Value)
for i := 0; i < cache.Len(); i++ {
rec := cache.Index(i)
cacheByDestinKey[getDestinKey(rec)] = rec
}
freshByDestinKey := make(map[interface{}]reflect.Value)
for _, joinRec := range joinRecs {
freshByDestinKey[joinRec.DestinKey] = joinRec.Record
}
var deletes []interface{}
updates := make(map[interface{}]ChangedFields)
insertsByDestinKey := make(map[interface{}]JoinRec)
// compare with cache: will result in delete or update
for i := 0; i < cache.Len(); i++ {
crec := cache.Index(i)
destinKey := getDestinKey(crec)
if frec, ok := freshByDestinKey[destinKey]; ok {
if frec.IsValid() {
// compare to maybe update
ifrec := frec.Elem().Interface()
icrec := crec.Elem().Interface()
cf, err := DiffRecord(ifrec, icrec, mtm.Scope)
if err != nil {
return errors.WithMessage(err, "diffing database records")
}
if cf != nil {
updates[destinKey] = cf
}
for _, jr := range joinRecs {
if jr.Record.IsValid() {
// many to many record was specified
err := c.Upsert(conn, mtm.Scope, jr.Record)
if err != nil {
return err
}
} else {
deletes = append(deletes, destinKey)
// create our own many to many record
err := c.ExecRaw(conn, upsertQuery, nil,
sourceKey, jr.DestinKey,
)
if err != nil {
return err
}
}
}
for _, joinRec := range joinRecs {
if _, ok := cacheByDestinKey[joinRec.DestinKey]; !ok {
insertsByDestinKey[joinRec.DestinKey] = joinRec
if mode == AssocModeReplace {
// this essentially clears all associated records
if len(joinRecs) == 0 {
err := c.ExecRaw(conn, deleteAllQuery, nil, sourceKey)
if err != nil {
return err
}
continue
}
}
if mode == AssocModeReplace && len(deletes) > 0 {
err := c.deletePagedByPK(conn, mtm.JoinTable, mtm.DestinDBName, deletes, builder.Eq{mtm.SourceDBName: sourceKey})
if err != nil {
return errors.WithMessage(err, "deleting extraneous relations")
passedDKs := make(map[interface{}]struct{})
for _, jr := range joinRecs {
passedDKs[jr.DestinKey] = struct{}{}
}
}
for _, joinRec := range insertsByDestinKey {
rec := joinRec.Record
// we have > 0 joinRecs, as checked above
firstDK := joinRecs[0].DestinKey
dkTyp := reflect.TypeOf(firstDK)
dkKind := dkTyp.Kind()
var removedDKs []interface{}
{
err := c.ExecRaw(conn, selectQuery, func(stmt *sqlite.Stmt) error {
var dk interface{}
switch dkKind {
case reflect.Int64:
dk = stmt.ColumnInt64(0)
case reflect.String:
dk = stmt.ColumnText(0)
default:
return errors.Errorf("Unsupported primary key for join table: %v", dkTyp)
}
if rec.IsValid() {
err := c.Insert(conn, mtm.Scope, rec)
if err != nil {
return errors.WithMessage(err, "creating new relation records")
}
} else {
// if not passed an explicit record, make it ourselves
// that typically means the join table doesn't have additional
// columns and is a simple many_to_many
eq := builder.Eq{
mtm.SourceDBName: sourceKey,
mtm.DestinDBName: joinRec.DestinKey,
}
query := builder.Insert(eq).Into(mtm.JoinTable)
err := c.Exec(conn, query, nil)
if _, ok := passedDKs[dk]; !ok {
removedDKs = append(removedDKs, dk)
}
return nil
}, sourceKey)
if err != nil {
return err
}
}
}
for destinKey, cf := range updates {
query := builder.Update(cf.ToEq()).Into(mtm.Scope.TableName()).Where(builder.Eq{mtm.SourceDBName: sourceKey, mtm.DestinDBName: destinKey})
err := c.Exec(conn, query, nil)
if err != nil {
return errors.WithMessage(err, "updating related records")
for _, dk := range removedDKs {
err := c.ExecRaw(conn, deleteQuery, nil, sourceKey, dk)
if err != nil {
return err
}
}
}
}
......
......@@ -4,8 +4,6 @@ import (
"math"
"reflect"
"github.com/go-xorm/builder"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
......@@ -115,90 +113,11 @@ func (c *Context) saveRows(conn *sqlite.Conn, mode AssocMode, inputIface interfa
return nil
}
primaryField := primaryFields[0]
// record should be a *SomeModel, we're effectively doing (*record).<pkColumn>
getKey := func(record reflect.Value) interface{} {
f := record.Elem().FieldByName(primaryField.Name)
if !f.IsValid() {
return nil
}
return f.Interface()
}
// collect primary key values for all of input
var keys []interface{}
for i := 0; i < fresh.Len(); i++ {
record := fresh.Index(i)
keys = append(keys, getKey(record))
}
cacheAddr, err := c.fetchPagedByPK(conn, primaryField.DBName, keys, fresh.Type(), Search{})
if err != nil {
return errors.WithMessage(err, "getting existing rows")
}
cache := cacheAddr.Elem()
// index cached items by their primary key
// so we can look them up in O(1) when comparing
cacheByPK := make(map[interface{}]reflect.Value)
for i := 0; i < cache.Len(); i++ {
record := cache.Index(i)
cacheByPK[getKey(record)] = record
}
// compare cached records with fresh records
var inserts []reflect.Value
var updates = make(map[interface{}]ChangedFields)
doneKeys := make(map[interface{}]bool)
for i := 0; i < fresh.Len(); i++ {
frec := fresh.Index(i)
key := getKey(frec)
if _, ok := doneKeys[key]; ok {
continue
}
doneKeys[key] = true
if crec, ok := cacheByPK[key]; ok {
// frec and crec are *SomeModel, but `RecordEqual` ignores pointer
// equality - we want to compare the contents of the struct
// so we indirect to SomeModel here.
ifrec := frec.Elem().Interface()
icrec := crec.Elem().Interface()
cf, err := DiffRecord(ifrec, icrec, scope)
if err != nil {
return errors.WithMessage(err, "diffing db records")
}
if cf != nil {
updates[key] = cf
}
} else {
inserts = append(inserts, frec)
}
}
c.Stats.Inserts += int64(len(inserts))
c.Stats.Updates += int64(len(updates))
c.Stats.Current += int64(fresh.Len() - len(updates) - len(inserts))
if len(inserts) > 0 {
for _, rec := range inserts {
err := c.Insert(conn, scope, rec)
if err != nil {
return errors.WithMessage(err, "inserting new DB records")
}
}
}
for key, cf := range updates {
eq := cf.ToEq()
err := c.Exec(conn, builder.Update(eq).Into(scope.TableName()).Where(builder.Eq{primaryField.DBName: key}), nil)
rec := fresh.Index(i)
err := c.Upsert(conn, scope, rec)
if err != nil {
return errors.WithMessage(err, "updating DB records")
return errors.WithMessage(err, "upserting DB records")
}
}
......
......@@ -38,6 +38,7 @@ func Test_Select(t *testing.T) {
if err != nil {
panic(err)
}
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
......@@ -135,6 +136,7 @@ func Test_SelectSquashed(t *testing.T) {
if err != nil {
panic(err)
}
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
......
package hades
import (
"fmt"
"reflect"
"strings"
"crawshaw.io/sqlite"
"github.com/go-xorm/builder"
)
// TODO: cache me
func (scope *Scope) ToSets() []string {
var sets []string
var processField func(sf *StructField)
processField = func(sf *StructField) {
if sf.IsSquashed {
for _, nsf := range sf.SquashedFields {
processField(nsf)
}
}
if !sf.IsNormal {
return
}
if sf.IsPrimaryKey {
return
}
name := EscapeIdentifier(sf.DBName)
sets = append(sets, fmt.Sprintf("%s=excluded.%s", name, name))
}
for _, sf := range scope.GetStructFields() {
processField(sf)
}
return sets
}
func (c *Context) Upsert(conn *sqlite.Conn, scope *Scope, rec reflect.Value) error {
eq := scope.ToEq(rec)
b := builder.Insert(eq).Into(scope.TableName())
sql, args, err := b.ToSQL()
if err != nil {
return err
}
sets := scope.ToSets()
if len(sets) == 0 {
sql = fmt.Sprintf("%s ON CONFLICT DO NOTHING",
sql,
)
} else {
var pfNames []string
for _, pf := range scope.GetModelStruct().PrimaryFields {
pfNames = append(pfNames, pf.DBName)
}
sql = fmt.Sprintf("%s ON CONFLICT(%s) DO UPDATE SET %s",
sql,
strings.Join(pfNames, ","),
strings.Join(sets, ","),
)
}
return c.ExecRaw(conn, sql, nil, args...)
}