Commit e8488ebd authored by Amos Wenger's avatar Amos Wenger

Take care of some TODOs

parent f03ab1c1
Pipeline #10423 passed with stage
in 44 seconds
......@@ -3,21 +3,21 @@ package hades
import (
"fmt"
"reflect"
"strings"
"time"
"github.com/go-xorm/builder"
"github.com/pkg/errors"
)
type ChangedFields map[string]interface{}
type ChangedFields map[*StructField]interface{}
func DiffRecord(x, y interface{}, scope *Scope) (ChangedFields, error) {
if x == nil || y == nil {
return nil, errors.New("DiffRecord: arguments must not be nil")
}
// v1 is the fresh record (from API)
// v1 is the fresh record (being saved)
v1 := reflect.ValueOf(x)
// v2 is the cached record (from DB)
// v2 is the cached record (in DB)
v2 := reflect.ValueOf(y)
if v1.Type() != v2.Type() {
return nil, errors.New("DiffRecord: arguments are not the same type")
......@@ -28,36 +28,26 @@ func DiffRecord(x, y interface{}, scope *Scope) (ChangedFields, error) {
return nil, errors.New("DiffRecord: arguments must be structs")
}
ms := scope.GetModelStruct()
var res ChangedFields
for i, n := 0, v1.NumField(); i < n; i++ {
f := typ.Field(i)
fieldName := f.Name
if strings.HasSuffix(fieldName, "ID") {
// ignore
sf, ok := ms.StructFieldsByName[fieldName]
if !ok {
// not listed as a field? ignore
continue
}
if f.Type.Kind() == reflect.Ptr {
// ignore
if !sf.IsNormal {
continue
}
if f.Type.Kind() == reflect.Slice {
// ignore
continue
}
v1f := v1.Field(i)
v2f := v2.Field(i)
if sf, ok := scope.FieldByName(fieldName); ok {
if sf.IsIgnored {
continue
}
} else {
// not listed as a field? ignore
continue
}
iseq, err := eq(v1.Field(i), v2.Field(i))
iseq, err := eq(v1f, v2f)
if err != nil {
return nil, errors.Wrap(err, "while comparing fields")
}
......@@ -66,9 +56,10 @@ func DiffRecord(x, y interface{}, scope *Scope) (ChangedFields, error) {
if res == nil {
res = make(ChangedFields)
}
res[fieldName] = v1.Field(i).Interface()
res[sf] = v1f.Interface()
}
}
return res, nil
}
......@@ -182,3 +173,11 @@ func indirectInterface(v reflect.Value) reflect.Value {
}
return v.Elem()
}
func (cf ChangedFields) ToEq() builder.Eq {
eq := make(builder.Eq)
for sf, v := range cf {
eq[sf.DBName] = v
}
return eq
}
......@@ -20,8 +20,7 @@ func (scope *Scope) ToEq(rec reflect.Value) builder.Eq {
eq := make(builder.Eq)
for _, sf := range scope.GetModelStruct().StructFields {
if sf.Relationship != nil {
// TODO: set IDs here?
if !sf.IsNormal {
continue
}
eq[sf.DBName] = recEl.FieldByName(sf.Name).Interface()
......
......@@ -36,10 +36,11 @@ var modelStructsMap = newModelStructsMap()
// ModelStruct model definition
type ModelStruct struct {
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
TableName string
PrimaryFields []*StructField
StructFields []*StructField
StructFieldsByName map[string]*StructField
ModelType reflect.Type
TableName string
}
// StructField model field's struct definition
......@@ -106,6 +107,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
modelStruct.TableName = TableName(reflectType)
modelStruct.ModelType = reflectType
modelStruct.StructFieldsByName = make(map[string]*StructField)
// Get all fields
for i := 0; i < reflectType.NumField(); i++ {
......@@ -466,6 +468,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
modelStructsMap.Set(reflectType, &modelStruct)
for _, sf := range modelStruct.StructFields {
modelStruct.StructFieldsByName[sf.Name] = sf
}
return &modelStruct
}
......
......@@ -96,29 +96,30 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
for _, joinRec := range inserts {
rec := joinRec.Record
if !rec.IsValid() {
if rec.IsValid() {
err := c.Insert(conn, mtm.Scope, rec)
if err != nil {
return errors.Wrap(err, "creating new relation records")
}
} else {
// if not passed an explicit record, make it ourselves
// that typically means the join table doesn't have additional
// columns and is a simple many2many
rec = reflect.New(joinType.Elem())
rec.Elem().FieldByName(mtm.SourceName).Set(reflect.ValueOf(sourceKey))
rec.Elem().FieldByName(mtm.DestinName).Set(reflect.ValueOf(joinRec.DestinKey))
}
// FIXME: that's slow/bad because of ToEq
err := c.Insert(conn, mtm.Scope, rec)
if err != nil {
return errors.Wrap(err, "creating new relation records")
eq := builder.Eq{
mtm.SourceDBName: sourceKey,
mtm.DestinDBName: joinRec.DestinKey,
}
query := builder.Insert(eq).Into(mtm.JoinTable)
err := c.Exec(conn, query, nil)
if err != nil {
return err
}
}
}
for destinKey, rec := range updates {
// FIXME: that's slow/bad
eq := make(builder.Eq)
for k, v := range rec {
eq[ToDBName(k)] = v
}
err := c.Exec(conn, builder.Update(eq).Into(mtm.Scope.TableName()).Where(builder.Eq{mtm.SourceDBName: sourceKey, mtm.DestinDBName: destinKey}), nil)
for destinKey, cf := range updates {
query := builder.Update(cf.ToEq()).Into(mtm.Scope.TableName()).Where(builder.Eq{mtm.SourceDBName: sourceKey, mtm.DestinDBName: destinKey})
err := c.Exec(conn, query, nil)
if err != nil {
return errors.Wrap(err, "updating related records")
}
......
......@@ -223,12 +223,8 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
}
}
for key, rec := range updates {
// FIXME: that's slow/bad
eq := make(builder.Eq)
for k, v := range rec {
eq[ToDBName(k)] = v
}
for key, cf := range updates {
eq := cf.ToEq()
err := c.Exec(conn, builder.Update(eq).Into(scope.TableName()).Where(builder.Eq{primaryField.DBName: key}), nil)
if err != nil {
return errors.Wrap(err, "updating DB records")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment