...
 
Commits (3)
......@@ -2,6 +2,7 @@ package hades_test
import (
"context"
"fmt"
"testing"
"time"
......@@ -46,7 +47,7 @@ func Test_AutoMigrate(t *testing.T) {
assert.False(t, pti[1].PrimaryKey)
assert.False(t, pti[1].NotNull)
ordie(c.SaveOne(conn, &User{ID: 123, FirstName: "Joanna"}))
ordie(c.Save(conn, &User{ID: 123, FirstName: "Joanna"}))
u := &User{}
foundUser, err := c.SelectOne(conn, u, builder.Eq{"id": 123})
ordie(err)
......@@ -193,7 +194,7 @@ func Test_AutoMigrateAllValidTypes(t *testing.T) {
FirstName: "Jeremy",
HeartRate: 3.14,
}
ordie(c.SaveOne(conn, h1))
ordie(c.Save(conn, h1))
h2 := &Humanoid{}
found, err := c.SelectOne(conn, h2, builder.Eq{"id": 12})
......@@ -269,6 +270,6 @@ func Test_AutoMigrateSquash(t *testing.T) {
func ordie(err error) {
if err != nil {
panic(err)
panic(fmt.Sprintf("%+v", err))
}
}
......@@ -36,22 +36,18 @@ func Test_BelongsTo(t *testing.T) {
Desc: "Consumer-grade flamethrowers",
}
t.Log("Saving one fate")
wtest.Must(t, c.SaveOne(conn, someFate))
wtest.Must(t, c.Save(conn, someFate))
lea := &Human{
ID: 3,
FateID: someFate.ID,
}
t.Log("Saving one human")
wtest.Must(t, c.SaveOne(conn, lea))
wtest.Must(t, c.Save(conn, lea))
t.Log("Preloading lea")
c.Preload(conn, &hades.PreloadParams{
Record: lea,
Fields: []hades.PreloadField{
{Name: "Fate"},
},
})
c.Preload(conn, lea, hades.Assoc("Fate"))
assert.NotNil(t, lea.Fate)
assert.EqualValues(t, someFate.Desc, lea.Fate.Desc)
})
......@@ -64,10 +60,7 @@ func Test_BelongsTo(t *testing.T) {
Desc: "Book authorship",
},
}
c.Save(conn, &hades.SaveParams{
Record: lea,
Assocs: []string{"Fate"},
})
wtest.Must(t, c.Save(conn, lea, hades.Assoc("Fate")))
fate := &Fate{}
found, err := c.SelectOne(conn, fate, builder.Eq{"id": 421})
......@@ -81,27 +74,21 @@ func Test_BelongsTo(t *testing.T) {
ID: 3,
Desc: "Space rodeo",
}
wtest.Must(t, c.SaveOne(conn, fate))
wtest.Must(t, c.Save(conn, fate))
human := &Human{
ID: 6,
FateID: 3,
}
wtest.Must(t, c.SaveOne(conn, human))
wtest.Must(t, c.Save(conn, human))
joke := &Joke{
ID: "neuf",
HumanID: 6,
}
wtest.Must(t, c.SaveOne(conn, joke))
wtest.Must(t, c.Save(conn, joke))
c.Preload(conn, &hades.PreloadParams{
Record: joke,
Fields: []hades.PreloadField{
{Name: "Human"},
{Name: "Human.Fate"},
},
})
c.Preload(conn, joke, hades.Assoc("Human", hades.Assoc("Fate")))
assert.NotNil(t, joke.Human)
assert.NotNil(t, joke.Human.Fate)
assert.EqualValues(t, "Space rodeo", joke.Human.Fate.Desc)
......
......@@ -37,7 +37,7 @@ func Test_Delete(t *testing.T) {
var count int64
var err error
wtest.Must(t, c.SaveOne(conn, stories))
wtest.Must(t, c.Save(conn, stories))
count, err = c.Count(conn, &Story{}, builder.NewCond())
wtest.Must(t, err)
......
......@@ -121,7 +121,7 @@ func iseq(sf *StructField, v1f reflect.Value, v2f reflect.Value) (bool, error) {
func (cf ChangedFields) ToEq() builder.Eq {
eq := make(builder.Eq)
for sf, v := range cf {
eq[sf.DBName] = DBValue(v)
eq[EscapeIdentifier(sf.DBName)] = DBValue(v)
}
return eq
}
package hades
import (
"fmt"
"reflect"
"github.com/pkg/errors"
......@@ -44,7 +43,7 @@ func (field *Field) Set(value interface{}) (err error) {
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else {
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
err = errors.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
}
}
} else {
......
......@@ -32,8 +32,13 @@ func withContext(t *testing.T, models []interface{}, f WithContextFunc) {
wtest.Must(t, err)
c.Log = true
// wtest.Must(t, c.AutoMigrate(conn))
c.AutoMigrate(conn)
wtest.Must(t, c.AutoMigrate(conn))
defer func() {
c.ScopeMap.Each(func(scope *hades.Scope) error {
return c.ExecRaw(conn, "DROP TABLE "+scope.TableName(), nil)
})
}()
f(conn, c)
}
......@@ -43,12 +43,12 @@ func Test_HasMany(t *testing.T) {
{ID: 11, Label: "Ability to not repeat oneself"},
},
}
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: p1}))
wtest.Must(t, c.Save(conn, p1, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 1)
assertCount(&Quality{}, 3)
p1.Qualities[2].Label = "Inspiration again"
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: p1}))
wtest.Must(t, c.Save(conn, p1, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 1)
assertCount(&Quality{}, 3)
{
......@@ -67,45 +67,35 @@ func Test_HasMany(t *testing.T) {
},
}
programmers := []*Programmer{p1, p2}
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: programmers}))
wtest.Must(t, c.Save(conn, programmers, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 2)
assertCount(&Quality{}, 5)
p1bis := &Programmer{ID: 3}
pp := &hades.PreloadParams{
Record: p1bis,
Fields: []hades.PreloadField{
{Name: "Qualities"},
},
}
wtest.Must(t, c.Preload(conn, pp))
wtest.Must(t, c.Preload(conn, p1bis, hades.Assoc("Qualities")))
assert.EqualValues(t, 3, len(p1bis.Qualities), "preload has_many")
wtest.Must(t, c.Preload(conn, pp))
wtest.Must(t, c.Preload(conn, p1bis, hades.Assoc("Qualities")))
assert.EqualValues(t, 3, len(p1bis.Qualities), "preload replaces, doesn't append")
pp.Fields[0] = hades.PreloadField{
Name: "Qualities",
Search: hades.Search().OrderBy("id asc"),
}
wtest.Must(t, c.Preload(conn, pp))
wtest.Must(t, c.Preload(conn, p1bis,
hades.AssocWithSearch("Qualities", hades.Search().OrderBy("id ASC"))),
)
assert.EqualValues(t, "Inspiration", p1bis.Qualities[0].Label, "orders by (asc)")
pp.Fields[0] = hades.PreloadField{
Name: "Qualities",
Search: hades.Search().OrderBy("id desc"),
}
wtest.Must(t, c.Preload(conn, pp))
wtest.Must(t, c.Preload(conn, p1bis,
hades.AssocWithSearch("Qualities", hades.Search().OrderBy("id DESC"))),
)
assert.EqualValues(t, "Inspiration again", p1bis.Qualities[0].Label, "orders by (desc)")
// no fields
assert.Error(t, c.Preload(conn, &hades.PreloadParams{Record: p1bis}))
assert.Error(t, c.Preload(conn, p1bis))
// not a model
assert.Error(t, c.Preload(conn, &hades.PreloadParams{Record: 42, Fields: pp.Fields}))
assert.Error(t, c.Preload(conn, 42, hades.Assoc("Qualities")))
// non-existent relation
assert.Error(t, c.Preload(conn, &hades.PreloadParams{Record: p1bis, Fields: []hades.PreloadField{{Name: "Woops"}}}))
assert.Error(t, c.Preload(conn, 42, hades.Assoc("Woops")))
})
}
......@@ -154,10 +144,7 @@ func Test_HasManyThorough(t *testing.T) {
t.Logf("...snip tons of INSERT...")
c.Log = false
ordie(c.Save(conn, &hades.SaveParams{
Record: car,
Assocs: []string{"Traits"},
}))
ordie(c.Save(conn, car, hades.Assoc("Traits")))
c.Log = true
numTraits := len(car.Traits)
......@@ -168,20 +155,13 @@ func Test_HasManyThorough(t *testing.T) {
car.Traits = nil
ordie(c.Save(conn, &hades.SaveParams{
Record: car,
Assocs: []string{"Traits"},
DontCull: []interface{}{&Trait{}},
}))
ordie(c.Save(conn, car, hades.Assoc("Traits")))
traitCount, err = c.Count(conn, &Trait{}, builder.NewCond())
ordie(err)
assert.EqualValues(t, numTraits, traitCount, "traits should still exist after partial-join save")
ordie(c.Save(conn, &hades.SaveParams{
Record: car,
Assocs: []string{"Traits"},
}))
ordie(c.Save(conn, car, hades.AssocReplace("Traits")))
traitCount, err = c.Count(conn, &Trait{}, builder.NewCond())
ordie(err)
......
......@@ -51,12 +51,12 @@ func Test_HasOne(t *testing.T) {
assert.EqualValues(t, expectedCount, count)
}
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: country, Assocs: []string{"Specialty"}}))
wtest.Must(t, c.Save(conn, country, hades.OmitRoot(), hades.Assoc("Specialty", hades.Assoc("Drawback"))))
assertCount(&Country{}, 0)
assertCount(&Specialty{}, 1)
assertCount(&Drawback{}, 1)
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: country}))
wtest.Must(t, c.Save(conn, country, hades.Assoc("Specialty", hades.Assoc("Drawback"))))
assertCount(&Country{}, 1)
assertCount(&Specialty{}, 1)
assertCount(&Drawback{}, 1)
......@@ -70,12 +70,8 @@ func Test_HasOne(t *testing.T) {
countries = append(countries, country)
}
wtest.Must(t, c.Preload(conn, &hades.PreloadParams{
Record: countries,
Fields: []hades.PreloadField{
{Name: "Specialty"},
{Name: "Specialty.Drawback"},
},
}))
wtest.Must(t, c.Preload(conn, countries,
hades.Assoc("Specialty",
hades.Assoc("Drawback"))))
})
}
......@@ -32,7 +32,7 @@ func (scope *Scope) ToEq(rec reflect.Value) builder.Eq {
if !sf.IsNormal {
return
}
eq[sf.DBName] = DBValue(field.Interface())
eq[EscapeIdentifier(sf.DBName)] = DBValue(field.Interface())
}
for _, sf := range scope.GetModelStruct().StructFields {
......
......@@ -39,9 +39,7 @@ func Test_ManyToMany(t *testing.T) {
},
}
t.Logf("saving just fr")
wtest.Must(t, c.Save(conn, &hades.SaveParams{
Record: fr,
}))
wtest.Must(t, c.Save(conn, fr, hades.Assoc("Words")))
assertCount := func(model interface{}, expectedCount int64) {
t.Helper()
......@@ -62,10 +60,7 @@ func Test_ManyToMany(t *testing.T) {
},
}
t.Logf("saving fr+en")
wtest.Must(t, c.Save(conn, &hades.SaveParams{
Record: []*Language{fr, en},
}))
wtest.Must(t, c.Save(conn, []*Language{fr, en}, hades.Assoc("Words")))
assertCount(&Language{}, 2)
assertCount(&Word{}, 2)
assertCount(&LanguageWord{}, 4)
......@@ -75,19 +70,14 @@ func Test_ManyToMany(t *testing.T) {
{ID: "Wreck"},
{ID: "Nervous"},
}
wtest.Must(t, c.Save(conn, &hades.SaveParams{
Record: []*Language{en},
DontCull: []interface{}{&LanguageWord{}},
}))
wtest.Must(t, c.Save(conn, []*Language{en}, hades.Assoc("Words")))
assertCount(&Language{}, 2)
assertCount(&Word{}, 4)
assertCount(&LanguageWord{}, 6)
t.Logf("replacing all english words")
wtest.Must(t, c.Save(conn, &hades.SaveParams{
Record: []*Language{en},
}))
wtest.Must(t, c.Save(conn, []*Language{en}, hades.AssocReplace("Words")))
assertCount(&Language{}, 2)
assertCount(&Word{}, 4)
......@@ -95,10 +85,7 @@ func Test_ManyToMany(t *testing.T) {
t.Logf("adding commentary")
en.Words[0].Comment = "punk band reference"
wtest.Must(t, c.Save(conn, &hades.SaveParams{
Record: []*Language{en},
}))
wtest.Must(t, c.Save(conn, []*Language{en}, hades.Assoc("Words")))
assertCount(&Language{}, 2)
assertCount(&Word{}, 4)
assertCount(&LanguageWord{}, 4)
......@@ -115,12 +102,7 @@ func Test_ManyToMany(t *testing.T) {
{ID: fr.ID},
{ID: en.ID},
}
err := c.Preload(conn, &hades.PreloadParams{
Record: langs,
Fields: []hades.PreloadField{
{Name: "Words"},
},
})
err := c.Preload(conn, langs, hades.Assoc("Words"))
// many_to_many preload is not implemented
assert.Error(t, err)
})
......@@ -179,9 +161,29 @@ func Test_ManyToManyRevenge(t *testing.T) {
}
}
p := makeProfile()
c.Save(conn, &hades.SaveParams{
Record: p,
})
wtest.Must(t, c.Save(conn, p,
hades.Assoc("ProfileGames",
hades.Assoc("Game"),
),
))
var names []struct {
Name string
}
wtest.Must(t, c.ExecWithSearch(conn,
builder.Select("games.title").
From("games").
LeftJoin("profile_games", builder.Expr("profile_games.game_id = games.id")),
hades.Search().OrderBy("profile_games.\"order\" ASC"),
c.IntoRowsScanner(&names),
))
assert.EqualValues(t, []struct {
Name string
}{
{"First offensive"},
{"Seconds until midnight"},
{"Three was company"},
}, names)
})
}
......@@ -237,7 +239,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p, hades.Assoc("Authors")))
pieceSelect := 1
pieceInsert := 1
......@@ -271,7 +273,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
......@@ -297,7 +299,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
......@@ -328,7 +330,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1
......@@ -358,14 +360,14 @@ func Test_ManyToManyThorough(t *testing.T) {
})
}
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1+1200)
assertCount(&PieceAuthor{}, len(p.Authors))
p.Authors = nil
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1+1200)
......
......@@ -52,7 +52,7 @@ func Test_Null(t *testing.T) {
ID: 123,
}
ordie(c.SaveOne(conn, d))
ordie(c.Save(conn, d))
{
dd := &Download{}
found, err := c.SelectOne(conn, dd, builder.Eq{"id": 123})
......@@ -74,7 +74,7 @@ func Test_Null(t *testing.T) {
finishedAt := time.Now()
d.FinishedAt = &finishedAt
ordie(c.SaveOne(conn, d))
ordie(c.Save(conn, d))
{
dd := &Download{}
......@@ -89,7 +89,7 @@ func Test_Null(t *testing.T) {
}
d.ErrorMessage = nil
ordie(c.SaveOne(conn, d))
ordie(c.Save(conn, d))
{
dd := &Download{}
......
......@@ -31,11 +31,11 @@ func (c *Context) fetchPagedByPK(conn *sqlite.Conn, PKDBName string, keys []inte
}
pageAddr := reflect.New(sliceType)
cond := builder.In(PKDBName, remainingItems[:pageSize]...)
cond := builder.In(EscapeIdentifier(PKDBName), remainingItems[:pageSize]...)
err := c.Select(conn, pageAddr.Interface(), cond, search)
if err != nil {
return result, errors.Wrap(err, "performing page fetch")
return result, errors.WithMessage(err, "performing page fetch")
}
appended := reflect.AppendSlice(resultVal, pageAddr.Elem())
......
package hades
type AssocMode int
const (
AssocModeAppend AssocMode = iota
AssocModeReplace
)
type assocField struct {
name string
search *SearchParams
mode AssocMode
children []AssocField
}
type saveParams struct {
assocs []AssocField
omitRoot bool
}
type preloadParams struct {
assocs []AssocField
}
type SaveParam interface {
ApplyToSaveParams(sp *saveParams)
}
type PreloadParam interface {
ApplyToPreloadParams(pp *preloadParams)
}
type AssocField interface {
SaveParam
PreloadParam
Name() string
Mode() AssocMode
Search() *SearchParams
Children() []AssocField
}
// -------------
// OmitRoot tells save to not save the record passed,
// but only associations
func OmitRoot() SaveParam {
return &omitRoot{}
}
type omitRoot struct{}
func (o *omitRoot) ApplyToSaveParams(sp *saveParams) {
sp.omitRoot = true
}
// Assoc tells save to save the specified association,
// but not to remove any existing associated records, even if
// they're not listed anymore
func Assoc(fieldName string, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeAppend,
children: children,
}
}
// AssocReplace tells save to save the specified assocation,
// and to remove any associated records that are no longer listed
func AssocReplace(fieldName string, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeReplace,
children: children,
}
}
func AssocWithSearch(fieldName string, search *SearchParams, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeAppend,
search: search,
children: children,
}
}
func (f *assocField) ApplyToSaveParams(sp *saveParams) {
sp.assocs = append(sp.assocs, f)
}
func (f *assocField) ApplyToPreloadParams(pp *preloadParams) {
pp.assocs = append(pp.assocs, f)
}
func (f *assocField) Name() string {
return f.name
}
func (f *assocField) Mode() AssocMode {
return f.mode
}
func (f *assocField) Children() []AssocField {
return f.children
}
func (f *assocField) Search() *SearchParams {
return f.search
}
......@@ -3,73 +3,19 @@ package hades
import (
"fmt"
"reflect"
"strings"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
type PreloadParams struct {
Record interface{}
// Fields to preload, for example []string{"CollectionGames", "CollectionGames.Game"}
Fields []PreloadField
}
type PreloadField struct {
Name string
Search *SearchParams
}
type Node struct {
Name string
Search *SearchParams
Field PreloadField
Children map[string]*Node
}
func NewNode(name string) *Node {
return &Node{
Name: name,
Children: make(map[string]*Node),
func (c *Context) Preload(conn *sqlite.Conn, rec interface{}, opts ...PreloadParam) error {
params := &preloadParams{}
for _, o := range opts {
o.ApplyToPreloadParams(params)
}
}
func (n *Node) String() string {
var res []string
res = append(res, fmt.Sprintf("- %s%s", n.Name, n.Search))
for _, c := range n.Children {
for _, cl := range strings.Split(c.String(), "\n") {
res = append(res, " "+cl)
}
}
return strings.Join(res, "\n")
}
func (n *Node) Add(pf PreloadField) {
tokens := strings.Split(pf.Name, ".")
name := tokens[0]
c, ok := n.Children[name]
if !ok {
c = NewNode(name)
n.Children[name] = c
}
if len(tokens) > 1 {
pfc := pf
pfc.Name = strings.Join(tokens[1:], ".")
c.Add(pfc)
} else {
c.Field = pf
c.Search = pf.Search
}
}
func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
rec := params.Record
if len(params.Fields) == 0 {
return errors.New("Preload expects a non-empty list in Fields")
if len(params.assocs) == 0 {
return errors.Errorf("Cannot preload 0 assocs")
}
val := reflect.ValueOf(rec)
......@@ -81,43 +27,31 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
valtyp = valtyp.Elem()
}
if valtyp.Kind() != reflect.Ptr {
return fmt.Errorf("Preload expects a []*Model or *Model, but it was passed a %v instead", val.Type())
return errors.Errorf("Preload expects a []*Model or *Model, but it was passed a %v instead", val.Type())
}
riMap := make(RecordInfoMap)
rootName := fmt.Sprintf("%v", valtyp)
typeTree, err := c.WalkType(riMap, rootName, valtyp, make(VisitMap), nil)
if err != nil {
return errors.Wrap(err, "waking type tree")
rootField := &assocField{
name: fmt.Sprintf("%v", valtyp),
mode: AssocModeAppend,
children: params.assocs,
}
valTree := NewNode(rootName)
for _, field := range params.Fields {
valTree.Add(field)
rootInfo, err := c.WalkType(riMap, rootField, valtyp)
if err != nil {
return errors.WithMessage(err, "waking type tree")
}
var walk func(p reflect.Value, pri *RecordInfo, pvt *Node) error
walk = func(p reflect.Value, pri *RecordInfo, pvt *Node) error {
for _, cvt := range pvt.Children {
var cri *RecordInfo
for _, c := range pri.Children {
if c.Name == cvt.Name {
cri = c
break
}
}
if cri == nil {
return fmt.Errorf("Relation not found: %s.%s", pri.Name, cvt.Name)
}
ptyp := p.Type()
if ptyp.Kind() == reflect.Slice {
ptyp = ptyp.Elem()
}
if ptyp.Kind() != reflect.Ptr {
return fmt.Errorf("walk expects a []*Model or *Model, but it was passed a %v instead", p.Type())
}
var walk func(p reflect.Value, pri *RecordInfo) error
walk = func(p reflect.Value, pri *RecordInfo) error {
ptyp := p.Type()
if ptyp.Kind() == reflect.Slice {
ptyp = ptyp.Elem()
}
if ptyp.Kind() != reflect.Ptr {
return errors.Errorf("walk expects a []*Model or *Model, but it was passed a %v instead", p.Type())
}
for _, cri := range pri.Children {
freshAddr := reflect.New(reflect.SliceOf(cri.Type))
var ps reflect.Value
......@@ -136,9 +70,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
}
var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search)
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil {
return errors.Wrap(err, "fetching has_many records (paginated)")
return errors.WithMessage(err, "fetching has_many records (paginated)")
}
pByFK := make(map[interface{}]reflect.Value)
......@@ -149,7 +83,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
// reset slices so if preload is called more than once,
// it doesn't keep appending
field := rec.Elem().FieldByName(cvt.Name)
field := rec.Elem().FieldByName(cri.Name())
field.Set(reflect.New(field.Type()).Elem())
}
......@@ -157,7 +91,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
for i := 0; i < fresh.Len(); i++ {
fk := fresh.Index(i).Elem().FieldByName(cri.Relationship.ForeignFieldNames[0]).Interface()
if p, ok := pByFK[fk]; ok {
dest := p.Elem().FieldByName(cvt.Name)
dest := p.Elem().FieldByName(cri.Name())
dest.Set(reflect.Append(dest, fresh.Index(i)))
}
}
......@@ -169,9 +103,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
}
var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search)
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil {
return errors.Wrap(err, "fetching has_one records (paginated)")
return errors.WithMessage(err, "fetching has_one records (paginated)")
}
fresh := freshAddr.Elem()
......@@ -186,7 +120,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
prec := ps.Index(i)
fk := prec.Elem().FieldByName(cri.Relationship.AssociationForeignFieldNames[0]).Interface()
if crec, ok := freshByFK[fk]; ok {
prec.Elem().FieldByName(cvt.Name).Set(crec)
prec.Elem().FieldByName(cri.Name()).Set(crec)
}
}
case "belongs_to":
......@@ -197,9 +131,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
}
var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.AssociationForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search)
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.AssociationForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil {
return errors.Wrap(err, "fetching belongs_to records (paginated)")
return errors.WithMessage(err, "fetching belongs_to records (paginated)")
}
fresh := freshAddr.Elem()
......@@ -214,23 +148,23 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
prec := ps.Index(i)
fk := prec.Elem().FieldByName(cri.Relationship.ForeignFieldNames[0]).Interface()
if crec, ok := freshByFK[fk]; ok {
prec.Elem().FieldByName(cvt.Name).Set(crec)
prec.Elem().FieldByName(cri.Name()).Set(crec)
}
}
default:
return fmt.Errorf("Preload doesn't know how to handle %s relationships", cri.Relationship.Kind)
return errors.Errorf("Preload doesn't know how to handle %s relationships", cri.Relationship.Kind)
}
fresh := freshAddr.Elem()
err = walk(fresh, cri, cvt)
err = walk(fresh, cri)
if err != nil {
return errors.WithStack(err)
}
}
return nil
}
err = walk(val, typeTree, valTree)
err = walk(val, rootInfo)
if err != nil {
return errors.WithStack(err)
}
......
......@@ -24,20 +24,10 @@ func Test_PreloadEdgeCases(t *testing.T) {
withContext(t, models, func(conn *sqlite.Conn, c *hades.Context) {
// non-existent Bar
f := &Foo{ID: 1, BarID: 999}
wtest.Must(t, c.Preload(conn, &hades.PreloadParams{
Record: f,
Fields: []hades.PreloadField{
{Name: "Bar"},
},
}))
wtest.Must(t, c.Preload(conn, f, hades.Assoc("Bar")))
// empty slice
var foos []*Foo
wtest.Must(t, c.Preload(conn, &hades.PreloadParams{
Record: foos,
Fields: []hades.PreloadField{
{Name: "Bar"},
},
}))
wtest.Must(t, c.Preload(conn, foos, hades.Assoc("Bar")))
})
}
package hades
import (
"fmt"
"reflect"
"strings"
"github.com/go-xorm/builder"
"github.com/itchio/hades/sqliteutil2"
......@@ -14,35 +14,16 @@ import (
type AllEntities map[reflect.Type]EntityMap
type EntityMap []interface{}
type SaveParams struct {
// Record to save
Record interface{}
// Fields to save instead of the top-level record
Assocs []string
// For has_many and many_to_many, never delete rows for these models
DontCull []interface{}
}
func (c *Context) SaveOne(conn *sqlite.Conn, record interface{}) (err error) {
return c.SaveNoTransaction(conn, &SaveParams{
Record: record,
})
}
func (c *Context) Save(conn *sqlite.Conn, params *SaveParams) (err error) {
func (c *Context) Save(conn *sqlite.Conn, rec interface{}, opts ...SaveParam) (err error) {
defer sqliteutil2.Save(conn)(&err)
return c.SaveNoTransaction(conn, params)
return c.SaveNoTransaction(conn, rec, opts...)
}
func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error {
if params == nil {
return errors.New("Save: params cannot be nil")
func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...SaveParam) error {
var params saveParams
for _, o := range opts {
o.ApplyToSaveParams(&params)
}
rec := params.Record
assocs := params.Assocs
val := reflect.ValueOf(rec)
valtyp := val.Type()
......@@ -50,13 +31,18 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
valtyp = valtyp.Elem()
}
if valtyp.Kind() != reflect.Ptr {
return fmt.Errorf("Save expects a []*Model or a *Model, but it was passed a %v instead", val.Type())
return errors.Errorf("Save expects a []*Model or a *Model, but it was passed a %v instead", val.Type())
}
riMap := make(RecordInfoMap)
tree, err := c.WalkType(riMap, "<root>", valtyp, make(VisitMap), assocs)
rootField := &assocField{
name: "<root>",
mode: AssocModeAppend,
children: params.assocs,
}
rootRecordInfo, err := c.WalkType(riMap, rootField, valtyp)
if err != nil {
return errors.Wrap(err, "walking records to be saved")
return errors.WithMessage(err, "walking records to be saved")
}
entities := make(AllEntities)
......@@ -75,7 +61,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
switch vri.Relationship.Kind {
case "has_many", "has_one":
if len(pri.ModelStruct.PrimaryFields) != 1 {
return fmt.Errorf("Since %v %s %v, we expected one primary key in %v, but found %d",
return errors.Errorf("Since %v %s %v, we expected one primary key in %v, but found %d",
p.Type(),
vri.Relationship.Kind,
v.Type(),
......@@ -85,7 +71,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
}
pkField := p.Elem().FieldByName(pri.ModelStruct.PrimaryFields[0].Name)
if len(vri.Relationship.ForeignFieldNames) != 1 {
return fmt.Errorf("Since %v %s %v, we expected one foreign field in %v, but found %d",
return errors.Errorf("Since %v %s %v, we expected one foreign field in %v, but found %d",
p.Type(),
vri.Relationship.Kind,
v.Type(),
......@@ -97,7 +83,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
fkField.Set(pkField)
case "belongs_to":
if len(vri.ModelStruct.PrimaryFields) != 1 {
return fmt.Errorf("Since %v %s %v, we expected one primary key in %v, but found %d",
return errors.Errorf("Since %v %s %v, we expected one primary key in %v, but found %d",
p.Type(),
vri.Relationship.Kind,
v.Type(),
......@@ -108,7 +94,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
pkField := v.Elem().FieldByName(vri.ModelStruct.PrimaryFields[0].Name)
if len(vri.Relationship.ForeignFieldNames) != 1 {
return fmt.Errorf("Since %v %s %v, we expected one foreign field in %v, but found %d",
return errors.Errorf("Since %v %s %v, we expected one foreign field in %v, but found %d",
p.Type(),
vri.Relationship.Kind,
v.Type(),
......@@ -126,21 +112,21 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
numVisited++
err := addEntity(v)
if err != nil {
return errors.Wrap(err, "adding entity")
return errors.WithMessage(err, "adding entity")
}
}
if v.Kind() != reflect.Ptr {
return fmt.Errorf("expected a pointer, but got with %v", v)
return errors.Errorf("expected a pointer, but got with %v", v)
}
vs := v.Elem()
if vs.Kind() != reflect.Struct {
return fmt.Errorf("expected a struct, but got with %v", v)
return errors.Errorf("expected a struct, but got with %v", v)
}
for _, childRi := range vri.Children {
child := vs.FieldByName(childRi.Name)
child := vs.FieldByName(childRi.Name())
if !child.IsValid() {
continue
}
......@@ -153,7 +139,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
persistChildren := true
err := walk(v, vri, child, childRi, persistChildren)
if err != nil {
return errors.Wrap(err, "walking child entities to be saved")
return errors.WithMessage(err, "walking child entities to be saved")
}
}
return nil
......@@ -162,15 +148,11 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
walk = func(p reflect.Value, pri *RecordInfo, v reflect.Value, vri *RecordInfo, persist bool) error {
if v.Kind() == reflect.Slice {
cull := false
if vri.Relationship != nil {
switch vri.Relationship.Kind {
case "has_many":
cull = true
for _, dc := range params.DontCull {
if reflect.TypeOf(dc).Elem() == vri.ModelStruct.ModelType {
cull = false
}
if vri.Field.Mode() == AssocModeReplace {
cull = true
}
case "many_to_many":
// culling is done later, but let's record the ManyToMany now
......@@ -181,7 +163,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
for i := 0; i < v.Len(); i++ {
err := visit(p, pri, v.Index(i), vri, persist)
if err != nil {
return errors.Wrap(err, "walking slice of children")
return errors.WithMessage(err, "walking slice of children")
}
}
......@@ -196,7 +178,12 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
parentPK := parentPF.Field
if len(vri.ModelStruct.PrimaryFields) != 1 {
return errors.Errorf("Since %v has_many %v", pri.Name, vri.Name)
var pfNames []string
for _, pf := range vri.ModelStruct.PrimaryFields {
pfNames = append(pfNames, pf.Name)
}
return errors.Errorf("Since %v has_many %v, expected %v to have one primary key. Instead, it has primary fields: %s",
pri.Name, vri.Name, strings.Join(pfNames, ", "))
}
valuePF := c.NewScope(v.Interface()).PrimaryField()
if valuePF == nil {
......@@ -247,30 +234,30 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
} else {
err := visit(p, pri, v, vri, persist)
if err != nil {
return errors.Wrap(err, "walking single child")
return errors.WithMessage(err, "walking single child")
}
}
return nil
}
persistRoot := assocs == nil
err = walk(reflect.Zero(reflect.TypeOf(0)), nil, val, tree, persistRoot)
err = walk(reflect.Zero(reflect.TypeOf(0)), nil, val, rootRecordInfo, !params.omitRoot)
if err != nil {
return errors.Wrap(err, "walking all records to be persisted")
return errors.WithMessage(err, "walking all records to be persisted")
}
for _, m := range entities {
err := c.saveRows(conn, params, m)
for typ, m := range entities {
ri := riMap[typ]
err := c.saveRows(conn, ri.Field.Mode(), m)
if err != nil {
return errors.Wrap(err, "saving rows")
return errors.WithMessage(err, "saving rows")
}
}
for _, ri := range riMap {
if ri.ManyToMany != nil {
err := c.saveJoins(params, conn, ri.ManyToMany)
err := c.saveJoins(conn, ri.Field.Mode(), ri.ManyToMany)
if err != nil {
return errors.Wrap(err, "saving joins")
return errors.WithMessage(err, "saving joins")
}
}
}
......
......@@ -8,15 +8,7 @@ import (
"github.com/pkg/errors"
)
func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMany) error {
cull := true
for _, dc := range params.DontCull {
if mtm.JoinTable == ToDBName(c.NewScope(dc).TableName()) {
cull = false
break
}
}
func (c *Context) saveJoins(conn *sqlite.Conn, mode AssocMode, mtm *ManyToMany) error {
joinType := reflect.PtrTo(mtm.Scope.GetModelStruct().ModelType)
getDestinKey := func(v reflect.Value) interface{} {
......@@ -28,7 +20,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
err := c.Select(conn, cacheAddr.Interface(), builder.Eq{mtm.SourceDBName: sourceKey}, nil)
if err != nil {
return errors.Wrap(err, "fetching cached records to compare later")
return errors.WithMessage(err, "fetching cached records to compare later")
}
cache := cacheAddr.Elem()
......@@ -60,7 +52,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
cf, err := DiffRecord(ifrec, icrec, mtm.Scope)
if err != nil {
return errors.Wrap(err, "diffing database records")
return errors.WithMessage(err, "diffing database records")
}
if cf != nil {
......@@ -78,14 +70,10 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
}
}
if !cull {
// Not deleting extra join records, as requested
} else {
if len(deletes) > 0 {
err := c.deletePagedByPK(conn, mtm.JoinTable, mtm.DestinDBName, deletes, builder.Eq{mtm.SourceDBName: sourceKey})
if err != nil {
return errors.Wrap(err, "deleting extraneous relations")
}
if mode == AssocModeReplace && len(deletes) > 0 {
err := c.deletePagedByPK(conn, mtm.JoinTable, mtm.DestinDBName, deletes, builder.Eq{mtm.SourceDBName: sourceKey})
if err != nil {
return errors.WithMessage(err, "deleting extraneous relations")
}
}
......@@ -95,7 +83,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
if rec.IsValid() {
err := c.Insert(conn, mtm.Scope, rec)
if err != nil {
return errors.Wrap(err, "creating new relation records")
return errors.WithMessage(err, "creating new relation records")
}
} else {
// if not passed an explicit record, make it ourselves
......@@ -117,7 +105,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
query := builder.Update(cf.ToEq()).Into(mtm.Scope.TableName()).Where(builder.Eq{mtm.SourceDBName: sourceKey, mtm.DestinDBName: destinKey})
err := c.Exec(conn, query, nil)
if err != nil {
return errors.Wrap(err, "updating related records")
return errors.WithMessage(err, "updating related records")
}
}
}
......
package hades
import (
"fmt"
"math"
"reflect"
"strings"
......@@ -12,7 +11,7 @@ import (
"github.com/pkg/errors"
)
func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface interface{}) error {
func (c *Context) saveRows(conn *sqlite.Conn, mode AssocMode, inputIface interface{}) error {
// inputIFace is a `[]interface{}`
input := reflect.ValueOf(inputIface)
if input.Kind() != reflect.Slice {
......@@ -40,7 +39,7 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
// this will happen for associations
if len(primaryFields) != 1 {
if len(primaryFields) != 2 {
return fmt.Errorf("Have %d primary keys for %s, don't know what to do", len(primaryFields), modelName)
return errors.Errorf("Have %d primary keys for %s, don't know what to do", len(primaryFields), modelName)
}
recordsGroupedByPrimaryField := make(map[*Field]map[interface{}][]reflect.Value)
......@@ -69,7 +68,7 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
}
if bestSourcePrimaryField == nil {
return fmt.Errorf("Have %d primary keys for %s, don't know what to do", len(primaryFields), modelName)
return errors.Errorf("Have %d primary keys for %s, don't know what to do", len(primaryFields), modelName)
}
var bestDestinPrimaryField *Field
......@@ -85,28 +84,28 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
sourceRelField, ok := scope.FieldByName(strings.TrimSuffix(bestSourcePrimaryField.Name, "ID"))
if !ok {
return fmt.Errorf("Could not find assoc for %s.%s", modelName, bestSourcePrimaryField.Name)
return errors.Errorf("Could not find assoc for %s.%s", modelName, bestSourcePrimaryField.Name)
}
destinRelField, ok := scope.FieldByName(strings.TrimSuffix(bestDestinPrimaryField.Name, "ID"))
if !ok {
return fmt.Errorf("Could not find assoc for %s.%s", modelName, bestDestinPrimaryField.Name)
return errors.Errorf("Could not find assoc for %s.%s", modelName, bestDestinPrimaryField.Name)
}
sourceScope := c.ScopeMap.ByType(sourceRelField.Struct.Type)
if sourceScope == nil {
return fmt.Errorf("Could not find scope for assoc for %s.%s", modelName, bestSourcePrimaryField.Name)
return errors.Errorf("Could not find scope for assoc for %s.%s", modelName, bestSourcePrimaryField.Name)
}
destinScope := c.ScopeMap.ByType(destinRelField.Struct.Type)
if destinScope == nil {
return fmt.Errorf("Could not find scope for assoc for %s.%s", modelName, bestSourcePrimaryField.Name)
return errors.Errorf("Could not find scope for assoc for %s.%s", modelName, bestSourcePrimaryField.Name)
}
if len(sourceScope.PrimaryFields()) != 1 {
return fmt.Errorf("Expected Source model %s to have 1 primary field, but it has %d",
return errors.Errorf("Expected Source model %s to have 1 primary field, but it has %d",
sourceScope.GetModelStruct().ModelType, len(sourceScope.PrimaryFields()))
}
if len(destinScope.PrimaryFields()) != 1 {
return fmt.Errorf("Expected Destin model %s to have 1 primary field, but it has %d",
return errors.Errorf("Expected Destin model %s to have 1 primary field, but it has %d",
destinScope.GetModelStruct().ModelType, len(destinScope.PrimaryFields()))
}
......@@ -126,7 +125,7 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
[]JoinTableForeignKey{destinJTFK},
)
if err != nil {
return errors.Wrap(err, "creating ManyToMany relationship")
return errors.WithMessage(err, "creating ManyToMany relationship")
}
for sourceKey, recs := range valueMap {
......@@ -136,9 +135,9 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
}
}
err = c.saveJoins(params, conn, mtm)
err = c.saveJoins(conn, mode, mtm)
if err != nil {
return errors.Wrap(err, "saving joins")
return errors.WithMessage(err, "saving joins")
}
return nil
......@@ -164,7 +163,7 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
cacheAddr, err := c.fetchPagedByPK(conn, primaryField.DBName, keys, fresh.Type(), nil)
if err != nil {
return errors.Wrap(err, "getting existing rows")
return errors.WithMessage(err, "getting existing rows")
}
cache := cacheAddr.Elem()
......@@ -199,7 +198,7 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
cf, err := DiffRecord(ifrec, icrec, scope)
if err != nil {
return errors.Wrap(err, "diffing db records")
return errors.WithMessage(err, "diffing db records")
}
if cf != nil {
......@@ -218,7 +217,7 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
for _, rec := range inserts {
err := c.Insert(conn, scope, rec)
if err != nil {
return errors.Wrap(err, "inserting new DB records")
return errors.WithMessage(err, "inserting new DB records")
}
}
}
......@@ -227,7 +226,7 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
eq := cf.ToEq()
err := c.Exec(conn, builder.Update(eq).Into(scope.TableName()).Where(builder.Eq{primaryField.DBName: key}), nil)
if err != nil {
return errors.Wrap(err, "updating DB records")
return errors.WithMessage(err, "updating DB records")
}
}
......
package hades_test
import (
"testing"
"crawshaw.io/sqlite"
"github.com/itchio/hades"
"github.com/itchio/wharf/wtest"
)
func Test_Save(t *testing.T) {
type Game struct {
ID int64
Title string
}
type CollectionGame struct {
ProfileID int64 `hades:"primary_key"`
GameID int64 `hades:"primary_key"`
}
type Profile struct {
ID int64
CollectionGames []*CollectionGame
}
models := []interface{}{
&Game{},
&CollectionGame{},
&Profile{},
}
withContext(t, models, func(conn *sqlite.Conn, c *hades.Context) {
p := &Profile{
ID: 1,
}
wtest.Must(t, c.Save(conn, p))
})
}
......@@ -28,7 +28,7 @@ func Test_Scan(t *testing.T) {
&GameEmbedData{},
}
withContext(t, models, func(conn *sqlite.Conn, c *hades.Context) {
wtest.Must(t, c.SaveOne(conn, []*Game{
wtest.Must(t, c.Save(conn, []*Game{
&Game{
ID: 24,
Title: "Jazz Jackrabbit",
......@@ -45,7 +45,7 @@ func Test_Scan(t *testing.T) {
Height: 240,
},
},
}))
}, hades.Assoc("EmbedData")))
var rows []struct {
Game `hades:"squash"`
......
......@@ -49,3 +49,13 @@ func (sm *ScopeMap) ByDBName(dbname string) *Scope {
func (sm *ScopeMap) ByType(typ reflect.Type) *Scope {
return sm.byType[typ]
}
func (sm *ScopeMap) Each(f func(*Scope) error) error {
for _, scope := range sm.byDBName {
err := f(scope)
if err != nil {
return err
}
}
return nil
}
......@@ -20,6 +20,7 @@ import (
"strings"