Commit d68cb367 authored by Amos Wenger's avatar Amos Wenger

Use upsert for joins as well

parent 26db24c5
Pipeline #11355 failed with stage
in 21 seconds
......@@ -45,27 +45,3 @@ func (c *Context) fetchPagedByPK(conn *sqlite.Conn, PKDBName string, keys []inte
return result, nil
}
func (c *Context) deletePagedByPK(conn *sqlite.Conn, TableName string, PKDBName string, keys []interface{}, userCond builder.Cond) error {
remainingItems := keys
for len(remainingItems) > 0 {
var pageSize int
if len(remainingItems) > maxSqlVars {
pageSize = maxSqlVars
} else {
pageSize = len(remainingItems)
}
cond := builder.And(userCond, builder.In(PKDBName, remainingItems[:pageSize]...))
query := builder.Delete(cond).From(TableName)
err := c.Exec(conn, query, nil)
if err != nil {
return err
}
remainingItems = remainingItems[pageSize:]
}
return nil
}
package hades
import (
"fmt"
"reflect"
"strings"
"github.com/go-xorm/builder"
"github.com/itchio/hades/sqliteutil2"
"crawshaw.io/sqlite"
......@@ -153,9 +153,8 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...
case "has_many":
// if we're in replace mode
if vri.Field.Mode() == AssocModeReplace {
// and it's an actually
// has_many, not a disguised
// many_to_many
// and it's an actual has_many,
// not a disguised many_to_many
if len(vri.ModelStruct.PrimaryFields) == 1 {
// then cull now
cull = true
......@@ -175,7 +174,6 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...
}
if cull {
var oldValuePKs []string
rel := vri.Relationship
parentPF := c.NewScope(p.Interface()).PrimaryField()
......@@ -193,49 +191,58 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...
return errors.Errorf("Since %v has_many %v, expected %v to have one primary key. Instead, it has primary fields: %s",
pri.Name(), vri.Name(), vri.Name(), strings.Join(pfNames, ", "))
}
valuePF := c.NewScope(v.Interface()).PrimaryField()
if valuePF == nil {
return errors.Errorf("Can't save %v has_many %v: value has no primary keys", pri.Type, vri.Type)
}
q := builder.Select(rel.AssociationForeignDBNames[0]).
From(vri.ModelStruct.TableName).
Where(builder.Eq{
rel.ForeignDBNames[0]: parentPK,
})
passedPFs := make(map[interface{}]struct{})
for i := 0; i < v.Len(); i++ {
rec := v.Index(i)
pf := rec.Elem().FieldByName(rel.AssociationForeignDBNames[0]).Interface()
passedPFs[pf] = struct{}{}
}
err = c.Exec(conn, q, func(stmt *sqlite.Stmt) error {
pk := stmt.ColumnText(0)
oldValuePKs = append(oldValuePKs, pk)
pfTyp := valuePF.Struct.Type
pfKind := pfTyp.Kind()
selectQuery := fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`,
EscapeIdentifier(rel.AssociationForeignDBNames[0]),
EscapeIdentifier(vri.ModelStruct.TableName),
EscapeIdentifier(rel.ForeignDBNames[0]),
)
deleteQuery := fmt.Sprintf(`DELETE FROM %s WHERE %s = ? AND %s = ?`,
EscapeIdentifier(vri.ModelStruct.TableName),
EscapeIdentifier(rel.ForeignDBNames[0]),
EscapeIdentifier(rel.AssociationForeignDBNames[0]),
)
var removedPFs []interface{}
err := c.ExecRaw(conn, selectQuery, func(stmt *sqlite.Stmt) error {
var pf interface{}
switch pfKind {
case reflect.Int64:
pf = stmt.ColumnInt64(0)
case reflect.String:
pf = stmt.ColumnText(0)
default:
return errors.Errorf("Unsupported primary key for has_many: %v", pfTyp)
}
if _, ok := passedPFs[pf]; !ok {
removedPFs = append(removedPFs, pf)
}
return nil
})
}, parentPK.Interface())
if err != nil {
return err
}
if len(oldValuePKs) > 0 {
var newValuePKs []string
for i := 0; i < v.Len(); i++ {
newValuePKs = append(newValuePKs, c.NewScope(v.Index(i).Interface()).PrimaryField().Field.String())
}
var newValuePKsMap = make(map[string]struct{})
for _, pk := range newValuePKs {
newValuePKsMap[pk] = struct{}{}
}
var vpksToDelete []interface{}
for _, pk := range oldValuePKs {
if _, ok := newValuePKsMap[pk]; !ok {
vpksToDelete = append(vpksToDelete, pk)
}
}
if len(vpksToDelete) > 0 {
err := c.deletePagedByPK(conn, vri.ModelStruct.TableName, valuePF.DBName, vpksToDelete, builder.NewCond())
if err != nil {
return err
}
for _, pf := range removedPFs {
err := c.ExecRaw(conn, deleteQuery, nil, parentPK.Interface(), pf)
if err != nil {
return err
}
}
}
......
package hades
import (
"fmt"
"reflect"
"crawshaw.io/sqlite"
"github.com/go-xorm/builder"
"github.com/pkg/errors"
)
func (c *Context) saveJoins(conn *sqlite.Conn, mode AssocMode, mtm *ManyToMany) error {
joinType := reflect.PtrTo(mtm.Scope.GetModelStruct().ModelType)
getDestinKey := func(v reflect.Value) interface{} {
return v.Elem().FieldByName(mtm.DestinName).Interface()
}
selectQuery := fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`,
EscapeIdentifier(mtm.DestinDBName),
EscapeIdentifier(mtm.JoinTable),
EscapeIdentifier(mtm.SourceDBName),
)
upsertQuery := fmt.Sprintf(`INSERT INTO %s (%s, %s) VALUES (?, ?) ON CONFLICT DO NOTHING`,
EscapeIdentifier(mtm.JoinTable),
EscapeIdentifier(mtm.SourceDBName),
EscapeIdentifier(mtm.DestinDBName),
)
deleteQuery := fmt.Sprintf(`DELETE FROM %s WHERE %s = ? AND %s = ?`,
EscapeIdentifier(mtm.JoinTable),
EscapeIdentifier(mtm.SourceDBName),
EscapeIdentifier(mtm.DestinDBName),
)
deleteAllQuery := fmt.Sprintf(`DELETE FROM %s WHERE %s = ?`,
EscapeIdentifier(mtm.JoinTable),
EscapeIdentifier(mtm.SourceDBName),
)
for sourceKey, joinRecs := range mtm.Values {
cacheAddr := reflect.New(reflect.SliceOf(joinType))
err := c.Select(conn, cacheAddr.Interface(), builder.Eq{mtm.SourceDBName: sourceKey}, Search{})
if err != nil {
return errors.WithMessage(err, "fetching cached records to compare later")
}
cache := cacheAddr.Elem()
cacheByDestinKey := make(map[interface{}]reflect.Value)
for i := 0; i < cache.Len(); i++ {
rec := cache.Index(i)
cacheByDestinKey[getDestinKey(rec)] = rec
}
freshByDestinKey := make(map[interface{}]reflect.Value)
for _, joinRec := range joinRecs {
freshByDestinKey[joinRec.DestinKey] = joinRec.Record
}
var deletes []interface{}
updates := make(map[interface{}]ChangedFields)
insertsByDestinKey := make(map[interface{}]JoinRec)
// compare with cache: will result in delete or update
for i := 0; i < cache.Len(); i++ {
crec := cache.Index(i)
destinKey := getDestinKey(crec)
if frec, ok := freshByDestinKey[destinKey]; ok {
if frec.IsValid() {
// compare to maybe update
ifrec := frec.Elem().Interface()
icrec := crec.Elem().Interface()
cf, err := DiffRecord(ifrec, icrec, mtm.Scope)
if err != nil {
return errors.WithMessage(err, "diffing database records")
}
if cf != nil {
updates[destinKey] = cf
}
for _, jr := range joinRecs {
if jr.Record.IsValid() {
// many to many record was specified
err := c.Upsert(conn, mtm.Scope, jr.Record)
if err != nil {
return err
}
} else {
deletes = append(deletes, destinKey)
// create our own many to many record
err := c.ExecRaw(conn, upsertQuery, nil,
sourceKey, jr.DestinKey,
)
if err != nil {
return err
}
}
}
for _, joinRec := range joinRecs {
if _, ok := cacheByDestinKey[joinRec.DestinKey]; !ok {
insertsByDestinKey[joinRec.DestinKey] = joinRec
if mode == AssocModeReplace {
// this essentially clears all associated records
if len(joinRecs) == 0 {
err := c.ExecRaw(conn, deleteAllQuery, nil, sourceKey)
if err != nil {
return err
}
continue
}
}
if mode == AssocModeReplace && len(deletes) > 0 {
err := c.deletePagedByPK(conn, mtm.JoinTable, mtm.DestinDBName, deletes, builder.Eq{mtm.SourceDBName: sourceKey})
if err != nil {
return errors.WithMessage(err, "deleting extraneous relations")
passedDKs := make(map[interface{}]struct{})
for _, jr := range joinRecs {
passedDKs[jr.DestinKey] = struct{}{}
}
}
for _, joinRec := range insertsByDestinKey {
rec := joinRec.Record
// we have > 0 joinRecs, as checked above
firstDK := joinRecs[0].DestinKey
dkTyp := reflect.TypeOf(firstDK)
dkKind := dkTyp.Kind()
var removedDKs []interface{}
{
err := c.ExecRaw(conn, selectQuery, func(stmt *sqlite.Stmt) error {
var dk interface{}
switch dkKind {
case reflect.Int64:
dk = stmt.ColumnInt64(0)
case reflect.String:
dk = stmt.ColumnText(0)
default:
return errors.Errorf("Unsupported primary key for join table: %v", dkTyp)
}
if rec.IsValid() {
err := c.Insert(conn, mtm.Scope, rec)
if err != nil {
return errors.WithMessage(err, "creating new relation records")
}
} else {
// if not passed an explicit record, make it ourselves
// that typically means the join table doesn't have additional
// columns and is a simple many_to_many
eq := builder.Eq{
mtm.SourceDBName: sourceKey,
mtm.DestinDBName: joinRec.DestinKey,
}
query := builder.Insert(eq).Into(mtm.JoinTable)
err := c.Exec(conn, query, nil)
if _, ok := passedDKs[dk]; !ok {
removedDKs = append(removedDKs, dk)
}
return nil
}, sourceKey)
if err != nil {
return err
}
}
}
for destinKey, cf := range updates {
query := builder.Update(cf.ToEq()).Into(mtm.Scope.TableName()).Where(builder.Eq{mtm.SourceDBName: sourceKey, mtm.DestinDBName: destinKey})
err := c.Exec(conn, query, nil)
if err != nil {
return errors.WithMessage(err, "updating related records")
for _, dk := range removedDKs {
err := c.ExecRaw(conn, deleteQuery, nil, sourceKey, dk)
if err != nil {
return err
}
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment