...
 
Commits (3)
...@@ -2,6 +2,7 @@ package hades_test ...@@ -2,6 +2,7 @@ package hades_test
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"time" "time"
...@@ -46,7 +47,7 @@ func Test_AutoMigrate(t *testing.T) { ...@@ -46,7 +47,7 @@ func Test_AutoMigrate(t *testing.T) {
assert.False(t, pti[1].PrimaryKey) assert.False(t, pti[1].PrimaryKey)
assert.False(t, pti[1].NotNull) assert.False(t, pti[1].NotNull)
ordie(c.SaveOne(conn, &User{ID: 123, FirstName: "Joanna"})) ordie(c.Save(conn, &User{ID: 123, FirstName: "Joanna"}))
u := &User{} u := &User{}
foundUser, err := c.SelectOne(conn, u, builder.Eq{"id": 123}) foundUser, err := c.SelectOne(conn, u, builder.Eq{"id": 123})
ordie(err) ordie(err)
...@@ -193,7 +194,7 @@ func Test_AutoMigrateAllValidTypes(t *testing.T) { ...@@ -193,7 +194,7 @@ func Test_AutoMigrateAllValidTypes(t *testing.T) {
FirstName: "Jeremy", FirstName: "Jeremy",
HeartRate: 3.14, HeartRate: 3.14,
} }
ordie(c.SaveOne(conn, h1)) ordie(c.Save(conn, h1))
h2 := &Humanoid{} h2 := &Humanoid{}
found, err := c.SelectOne(conn, h2, builder.Eq{"id": 12}) found, err := c.SelectOne(conn, h2, builder.Eq{"id": 12})
...@@ -269,6 +270,6 @@ func Test_AutoMigrateSquash(t *testing.T) { ...@@ -269,6 +270,6 @@ func Test_AutoMigrateSquash(t *testing.T) {
func ordie(err error) { func ordie(err error) {
if err != nil { if err != nil {
panic(err) panic(fmt.Sprintf("%+v", err))
} }
} }
...@@ -36,22 +36,18 @@ func Test_BelongsTo(t *testing.T) { ...@@ -36,22 +36,18 @@ func Test_BelongsTo(t *testing.T) {
Desc: "Consumer-grade flamethrowers", Desc: "Consumer-grade flamethrowers",
} }
t.Log("Saving one fate") t.Log("Saving one fate")
wtest.Must(t, c.SaveOne(conn, someFate)) wtest.Must(t, c.Save(conn, someFate))
lea := &Human{ lea := &Human{
ID: 3, ID: 3,
FateID: someFate.ID, FateID: someFate.ID,
} }
t.Log("Saving one human") t.Log("Saving one human")
wtest.Must(t, c.SaveOne(conn, lea)) wtest.Must(t, c.Save(conn, lea))
t.Log("Preloading lea") t.Log("Preloading lea")
c.Preload(conn, &hades.PreloadParams{ c.Preload(conn, lea, hades.Assoc("Fate"))
Record: lea,
Fields: []hades.PreloadField{
{Name: "Fate"},
},
})
assert.NotNil(t, lea.Fate) assert.NotNil(t, lea.Fate)
assert.EqualValues(t, someFate.Desc, lea.Fate.Desc) assert.EqualValues(t, someFate.Desc, lea.Fate.Desc)
}) })
...@@ -64,10 +60,7 @@ func Test_BelongsTo(t *testing.T) { ...@@ -64,10 +60,7 @@ func Test_BelongsTo(t *testing.T) {
Desc: "Book authorship", Desc: "Book authorship",
}, },
} }
c.Save(conn, &hades.SaveParams{ wtest.Must(t, c.Save(conn, lea, hades.Assoc("Fate")))
Record: lea,
Assocs: []string{"Fate"},
})
fate := &Fate{} fate := &Fate{}
found, err := c.SelectOne(conn, fate, builder.Eq{"id": 421}) found, err := c.SelectOne(conn, fate, builder.Eq{"id": 421})
...@@ -81,27 +74,21 @@ func Test_BelongsTo(t *testing.T) { ...@@ -81,27 +74,21 @@ func Test_BelongsTo(t *testing.T) {
ID: 3, ID: 3,
Desc: "Space rodeo", Desc: "Space rodeo",
} }
wtest.Must(t, c.SaveOne(conn, fate)) wtest.Must(t, c.Save(conn, fate))
human := &Human{ human := &Human{
ID: 6, ID: 6,
FateID: 3, FateID: 3,
} }
wtest.Must(t, c.SaveOne(conn, human)) wtest.Must(t, c.Save(conn, human))
joke := &Joke{ joke := &Joke{
ID: "neuf", ID: "neuf",
HumanID: 6, HumanID: 6,
} }
wtest.Must(t, c.SaveOne(conn, joke)) wtest.Must(t, c.Save(conn, joke))
c.Preload(conn, &hades.PreloadParams{ c.Preload(conn, joke, hades.Assoc("Human", hades.Assoc("Fate")))
Record: joke,
Fields: []hades.PreloadField{
{Name: "Human"},
{Name: "Human.Fate"},
},
})
assert.NotNil(t, joke.Human) assert.NotNil(t, joke.Human)
assert.NotNil(t, joke.Human.Fate) assert.NotNil(t, joke.Human.Fate)
assert.EqualValues(t, "Space rodeo", joke.Human.Fate.Desc) assert.EqualValues(t, "Space rodeo", joke.Human.Fate.Desc)
......
...@@ -37,7 +37,7 @@ func Test_Delete(t *testing.T) { ...@@ -37,7 +37,7 @@ func Test_Delete(t *testing.T) {
var count int64 var count int64
var err error var err error
wtest.Must(t, c.SaveOne(conn, stories)) wtest.Must(t, c.Save(conn, stories))
count, err = c.Count(conn, &Story{}, builder.NewCond()) count, err = c.Count(conn, &Story{}, builder.NewCond())
wtest.Must(t, err) wtest.Must(t, err)
......
...@@ -121,7 +121,7 @@ func iseq(sf *StructField, v1f reflect.Value, v2f reflect.Value) (bool, error) { ...@@ -121,7 +121,7 @@ func iseq(sf *StructField, v1f reflect.Value, v2f reflect.Value) (bool, error) {
func (cf ChangedFields) ToEq() builder.Eq { func (cf ChangedFields) ToEq() builder.Eq {
eq := make(builder.Eq) eq := make(builder.Eq)
for sf, v := range cf { for sf, v := range cf {
eq[sf.DBName] = DBValue(v) eq[EscapeIdentifier(sf.DBName)] = DBValue(v)
} }
return eq return eq
} }
package hades package hades
import ( import (
"fmt"
"reflect" "reflect"
"github.com/pkg/errors" "github.com/pkg/errors"
...@@ -44,7 +43,7 @@ func (field *Field) Set(value interface{}) (err error) { ...@@ -44,7 +43,7 @@ func (field *Field) Set(value interface{}) (err error) {
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type())) fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else { } else {
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) err = errors.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
} }
} }
} else { } else {
......
...@@ -32,8 +32,13 @@ func withContext(t *testing.T, models []interface{}, f WithContextFunc) { ...@@ -32,8 +32,13 @@ func withContext(t *testing.T, models []interface{}, f WithContextFunc) {
wtest.Must(t, err) wtest.Must(t, err)
c.Log = true c.Log = true
// wtest.Must(t, c.AutoMigrate(conn)) wtest.Must(t, c.AutoMigrate(conn))
c.AutoMigrate(conn)
defer func() {
c.ScopeMap.Each(func(scope *hades.Scope) error {
return c.ExecRaw(conn, "DROP TABLE "+scope.TableName(), nil)
})
}()
f(conn, c) f(conn, c)
} }
...@@ -43,12 +43,12 @@ func Test_HasMany(t *testing.T) { ...@@ -43,12 +43,12 @@ func Test_HasMany(t *testing.T) {
{ID: 11, Label: "Ability to not repeat oneself"}, {ID: 11, Label: "Ability to not repeat oneself"},
}, },
} }
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: p1})) wtest.Must(t, c.Save(conn, p1, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 1) assertCount(&Programmer{}, 1)
assertCount(&Quality{}, 3) assertCount(&Quality{}, 3)
p1.Qualities[2].Label = "Inspiration again" p1.Qualities[2].Label = "Inspiration again"
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: p1})) wtest.Must(t, c.Save(conn, p1, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 1) assertCount(&Programmer{}, 1)
assertCount(&Quality{}, 3) assertCount(&Quality{}, 3)
{ {
...@@ -67,45 +67,35 @@ func Test_HasMany(t *testing.T) { ...@@ -67,45 +67,35 @@ func Test_HasMany(t *testing.T) {
}, },
} }
programmers := []*Programmer{p1, p2} programmers := []*Programmer{p1, p2}
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: programmers})) wtest.Must(t, c.Save(conn, programmers, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 2) assertCount(&Programmer{}, 2)
assertCount(&Quality{}, 5) assertCount(&Quality{}, 5)
p1bis := &Programmer{ID: 3} p1bis := &Programmer{ID: 3}
pp := &hades.PreloadParams{ wtest.Must(t, c.Preload(conn, p1bis, hades.Assoc("Qualities")))
Record: p1bis,
Fields: []hades.PreloadField{
{Name: "Qualities"},
},
}
wtest.Must(t, c.Preload(conn, pp))
assert.EqualValues(t, 3, len(p1bis.Qualities), "preload has_many") assert.EqualValues(t, 3, len(p1bis.Qualities), "preload has_many")
wtest.Must(t, c.Preload(conn, pp)) wtest.Must(t, c.Preload(conn, p1bis, hades.Assoc("Qualities")))
assert.EqualValues(t, 3, len(p1bis.Qualities), "preload replaces, doesn't append") assert.EqualValues(t, 3, len(p1bis.Qualities), "preload replaces, doesn't append")
pp.Fields[0] = hades.PreloadField{ wtest.Must(t, c.Preload(conn, p1bis,
Name: "Qualities", hades.AssocWithSearch("Qualities", hades.Search().OrderBy("id ASC"))),
Search: hades.Search().OrderBy("id asc"), )
}
wtest.Must(t, c.Preload(conn, pp))
assert.EqualValues(t, "Inspiration", p1bis.Qualities[0].Label, "orders by (asc)") assert.EqualValues(t, "Inspiration", p1bis.Qualities[0].Label, "orders by (asc)")
pp.Fields[0] = hades.PreloadField{ wtest.Must(t, c.Preload(conn, p1bis,
Name: "Qualities", hades.AssocWithSearch("Qualities", hades.Search().OrderBy("id DESC"))),
Search: hades.Search().OrderBy("id desc"), )
}
wtest.Must(t, c.Preload(conn, pp))
assert.EqualValues(t, "Inspiration again", p1bis.Qualities[0].Label, "orders by (desc)") assert.EqualValues(t, "Inspiration again", p1bis.Qualities[0].Label, "orders by (desc)")
// no fields // no fields
assert.Error(t, c.Preload(conn, &hades.PreloadParams{Record: p1bis})) assert.Error(t, c.Preload(conn, p1bis))
// not a model // not a model
assert.Error(t, c.Preload(conn, &hades.PreloadParams{Record: 42, Fields: pp.Fields})) assert.Error(t, c.Preload(conn, 42, hades.Assoc("Qualities")))
// non-existent relation // non-existent relation
assert.Error(t, c.Preload(conn, &hades.PreloadParams{Record: p1bis, Fields: []hades.PreloadField{{Name: "Woops"}}})) assert.Error(t, c.Preload(conn, 42, hades.Assoc("Woops")))
}) })
} }
...@@ -154,10 +144,7 @@ func Test_HasManyThorough(t *testing.T) { ...@@ -154,10 +144,7 @@ func Test_HasManyThorough(t *testing.T) {
t.Logf("...snip tons of INSERT...") t.Logf("...snip tons of INSERT...")
c.Log = false c.Log = false
ordie(c.Save(conn, &hades.SaveParams{ ordie(c.Save(conn, car, hades.Assoc("Traits")))
Record: car,
Assocs: []string{"Traits"},
}))
c.Log = true c.Log = true
numTraits := len(car.Traits) numTraits := len(car.Traits)
...@@ -168,20 +155,13 @@ func Test_HasManyThorough(t *testing.T) { ...@@ -168,20 +155,13 @@ func Test_HasManyThorough(t *testing.T) {
car.Traits = nil car.Traits = nil
ordie(c.Save(conn, &hades.SaveParams{ ordie(c.Save(conn, car, hades.Assoc("Traits")))
Record: car,
Assocs: []string{"Traits"},
DontCull: []interface{}{&Trait{}},
}))
traitCount, err = c.Count(conn, &Trait{}, builder.NewCond()) traitCount, err = c.Count(conn, &Trait{}, builder.NewCond())
ordie(err) ordie(err)
assert.EqualValues(t, numTraits, traitCount, "traits should still exist after partial-join save") assert.EqualValues(t, numTraits, traitCount, "traits should still exist after partial-join save")
ordie(c.Save(conn, &hades.SaveParams{ ordie(c.Save(conn, car, hades.AssocReplace("Traits")))
Record: car,
Assocs: []string{"Traits"},
}))
traitCount, err = c.Count(conn, &Trait{}, builder.NewCond()) traitCount, err = c.Count(conn, &Trait{}, builder.NewCond())
ordie(err) ordie(err)
......
...@@ -51,12 +51,12 @@ func Test_HasOne(t *testing.T) { ...@@ -51,12 +51,12 @@ func Test_HasOne(t *testing.T) {
assert.EqualValues(t, expectedCount, count) assert.EqualValues(t, expectedCount, count)
} }
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: country, Assocs: []string{"Specialty"}})) wtest.Must(t, c.Save(conn, country, hades.OmitRoot(), hades.Assoc("Specialty", hades.Assoc("Drawback"))))
assertCount(&Country{}, 0) assertCount(&Country{}, 0)
assertCount(&Specialty{}, 1) assertCount(&Specialty{}, 1)
assertCount(&Drawback{}, 1) assertCount(&Drawback{}, 1)
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: country})) wtest.Must(t, c.Save(conn, country, hades.Assoc("Specialty", hades.Assoc("Drawback"))))
assertCount(&Country{}, 1) assertCount(&Country{}, 1)
assertCount(&Specialty{}, 1) assertCount(&Specialty{}, 1)
assertCount(&Drawback{}, 1) assertCount(&Drawback{}, 1)
...@@ -70,12 +70,8 @@ func Test_HasOne(t *testing.T) { ...@@ -70,12 +70,8 @@ func Test_HasOne(t *testing.T) {
countries = append(countries, country) countries = append(countries, country)
} }
wtest.Must(t, c.Preload(conn, &hades.PreloadParams{ wtest.Must(t, c.Preload(conn, countries,
Record: countries, hades.Assoc("Specialty",
Fields: []hades.PreloadField{ hades.Assoc("Drawback"))))
{Name: "Specialty"},
{Name: "Specialty.Drawback"},
},
}))
}) })
} }
...@@ -32,7 +32,7 @@ func (scope *Scope) ToEq(rec reflect.Value) builder.Eq { ...@@ -32,7 +32,7 @@ func (scope *Scope) ToEq(rec reflect.Value) builder.Eq {
if !sf.IsNormal { if !sf.IsNormal {
return return
} }
eq[sf.DBName] = DBValue(field.Interface()) eq[EscapeIdentifier(sf.DBName)] = DBValue(field.Interface())
} }
for _, sf := range scope.GetModelStruct().StructFields { for _, sf := range scope.GetModelStruct().StructFields {
......
...@@ -39,9 +39,7 @@ func Test_ManyToMany(t *testing.T) { ...@@ -39,9 +39,7 @@ func Test_ManyToMany(t *testing.T) {
}, },
} }
t.Logf("saving just fr") t.Logf("saving just fr")
wtest.Must(t, c.Save(conn, &hades.SaveParams{ wtest.Must(t, c.Save(conn, fr, hades.Assoc("Words")))
Record: fr,
}))
assertCount := func(model interface{}, expectedCount int64) { assertCount := func(model interface{}, expectedCount int64) {
t.Helper() t.Helper()
...@@ -62,10 +60,7 @@ func Test_ManyToMany(t *testing.T) { ...@@ -62,10 +60,7 @@ func Test_ManyToMany(t *testing.T) {
}, },
} }
t.Logf("saving fr+en") t.Logf("saving fr+en")
wtest.Must(t, c.Save(conn, &hades.SaveParams{ wtest.Must(t, c.Save(conn, []*Language{fr, en}, hades.Assoc("Words")))
Record: []*Language{fr, en},
}))
assertCount(&Language{}, 2) assertCount(&Language{}, 2)
assertCount(&Word{}, 2) assertCount(&Word{}, 2)
assertCount(&LanguageWord{}, 4) assertCount(&LanguageWord{}, 4)
...@@ -75,19 +70,14 @@ func Test_ManyToMany(t *testing.T) { ...@@ -75,19 +70,14 @@ func Test_ManyToMany(t *testing.T) {
{ID: "Wreck"}, {ID: "Wreck"},
{ID: "Nervous"}, {ID: "Nervous"},
} }
wtest.Must(t, c.Save(conn, &hades.SaveParams{ wtest.Must(t, c.Save(conn, []*Language{en}, hades.Assoc("Words")))
Record: []*Language{en},
DontCull: []interface{}{&LanguageWord{}},
}))
assertCount(&Language{}, 2) assertCount(&Language{}, 2)
assertCount(&Word{}, 4) assertCount(&Word{}, 4)
assertCount(&LanguageWord{}, 6) assertCount(&LanguageWord{}, 6)
t.Logf("replacing all english words") t.Logf("replacing all english words")
wtest.Must(t, c.Save(conn, &hades.SaveParams{ wtest.Must(t, c.Save(conn, []*Language{en}, hades.AssocReplace("Words")))
Record: []*Language{en},
}))
assertCount(&Language{}, 2) assertCount(&Language{}, 2)
assertCount(&Word{}, 4) assertCount(&Word{}, 4)
...@@ -95,10 +85,7 @@ func Test_ManyToMany(t *testing.T) { ...@@ -95,10 +85,7 @@ func Test_ManyToMany(t *testing.T) {
t.Logf("adding commentary") t.Logf("adding commentary")
en.Words[0].Comment = "punk band reference" en.Words[0].Comment = "punk band reference"
wtest.Must(t, c.Save(conn, &hades.SaveParams{ wtest.Must(t, c.Save(conn, []*Language{en}, hades.Assoc("Words")))
Record: []*Language{en},
}))
assertCount(&Language{}, 2) assertCount(&Language{}, 2)
assertCount(&Word{}, 4) assertCount(&Word{}, 4)
assertCount(&LanguageWord{}, 4) assertCount(&LanguageWord{}, 4)
...@@ -115,12 +102,7 @@ func Test_ManyToMany(t *testing.T) { ...@@ -115,12 +102,7 @@ func Test_ManyToMany(t *testing.T) {
{ID: fr.ID}, {ID: fr.ID},
{ID: en.ID}, {ID: en.ID},
} }
err := c.Preload(conn, &hades.PreloadParams{ err := c.Preload(conn, langs, hades.Assoc("Words"))
Record: langs,
Fields: []hades.PreloadField{
{Name: "Words"},
},
})
// many_to_many preload is not implemented // many_to_many preload is not implemented
assert.Error(t, err) assert.Error(t, err)
}) })
...@@ -179,9 +161,29 @@ func Test_ManyToManyRevenge(t *testing.T) { ...@@ -179,9 +161,29 @@ func Test_ManyToManyRevenge(t *testing.T) {
} }
} }
p := makeProfile() p := makeProfile()
c.Save(conn, &hades.SaveParams{ wtest.Must(t, c.Save(conn, p,
Record: p, hades.Assoc("ProfileGames",
}) hades.Assoc("Game"),
),
))
var names []struct {
Name string
}
wtest.Must(t, c.ExecWithSearch(conn,
builder.Select("games.title").
From("games").
LeftJoin("profile_games", builder.Expr("profile_games.game_id = games.id")),
hades.Search().OrderBy("profile_games.\"order\" ASC"),
c.IntoRowsScanner(&names),
))
assert.EqualValues(t, []struct {
Name string
}{
{"First offensive"},
{"Seconds until midnight"},
{"Three was company"},
}, names)
}) })
} }
...@@ -237,7 +239,7 @@ func Test_ManyToManyThorough(t *testing.T) { ...@@ -237,7 +239,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{ {
beforeSaveQueryCount := c.QueryCount beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p)) ordie(c.Save(conn, p, hades.Assoc("Authors")))
pieceSelect := 1 pieceSelect := 1
pieceInsert := 1 pieceInsert := 1
...@@ -271,7 +273,7 @@ func Test_ManyToManyThorough(t *testing.T) { ...@@ -271,7 +273,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{ {
beforeSaveQueryCount := c.QueryCount beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p)) ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1 pieceSelect := 1
...@@ -297,7 +299,7 @@ func Test_ManyToManyThorough(t *testing.T) { ...@@ -297,7 +299,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{ {
beforeSaveQueryCount := c.QueryCount beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p)) ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1 pieceSelect := 1
...@@ -328,7 +330,7 @@ func Test_ManyToManyThorough(t *testing.T) { ...@@ -328,7 +330,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{ {
beforeSaveQueryCount := c.QueryCount beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p)) ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
pieceSelect := 1 pieceSelect := 1
...@@ -358,14 +360,14 @@ func Test_ManyToManyThorough(t *testing.T) { ...@@ -358,14 +360,14 @@ func Test_ManyToManyThorough(t *testing.T) {
}) })
} }
ordie(c.SaveOne(conn, p)) ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1) assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1+1200) assertCount(&Author{}, len(originalAuthors)+1+1200)
assertCount(&PieceAuthor{}, len(p.Authors)) assertCount(&PieceAuthor{}, len(p.Authors))
p.Authors = nil p.Authors = nil
ordie(c.SaveOne(conn, p)) ordie(c.Save(conn, p, hades.AssocReplace("Authors")))
assertCount(&Piece{}, 1) assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1+1200) assertCount(&Author{}, len(originalAuthors)+1+1200)
......
...@@ -52,7 +52,7 @@ func Test_Null(t *testing.T) { ...@@ -52,7 +52,7 @@ func Test_Null(t *testing.T) {
ID: 123, ID: 123,
} }
ordie(c.SaveOne(conn, d)) ordie(c.Save(conn, d))
{ {
dd := &Download{} dd := &Download{}
found, err := c.SelectOne(conn, dd, builder.Eq{"id": 123}) found, err := c.SelectOne(conn, dd, builder.Eq{"id": 123})
...@@ -74,7 +74,7 @@ func Test_Null(t *testing.T) { ...@@ -74,7 +74,7 @@ func Test_Null(t *testing.T) {
finishedAt := time.Now() finishedAt := time.Now()
d.FinishedAt = &finishedAt d.FinishedAt = &finishedAt
ordie(c.SaveOne(conn, d)) ordie(c.Save(conn, d))
{ {
dd := &Download{} dd := &Download{}
...@@ -89,7 +89,7 @@ func Test_Null(t *testing.T) { ...@@ -89,7 +89,7 @@ func Test_Null(t *testing.T) {
} }
d.ErrorMessage = nil d.ErrorMessage = nil
ordie(c.SaveOne(conn, d)) ordie(c.Save(conn, d))
{ {
dd := &Download{} dd := &Download{}
......
...@@ -31,11 +31,11 @@ func (c *Context) fetchPagedByPK(conn *sqlite.Conn, PKDBName string, keys []inte ...@@ -31,11 +31,11 @@ func (c *Context) fetchPagedByPK(conn *sqlite.Conn, PKDBName string, keys []inte
} }
pageAddr := reflect.New(sliceType) pageAddr := reflect.New(sliceType)
cond := builder.In(PKDBName, remainingItems[:pageSize]...) cond := builder.In(EscapeIdentifier(PKDBName), remainingItems[:pageSize]...)
err := c.Select(conn, pageAddr.Interface(), cond, search) err := c.Select(conn, pageAddr.Interface(), cond, search)
if err != nil { if err != nil {
return result, errors.Wrap(err, "performing page fetch") return result, errors.WithMessage(err, "performing page fetch")
} }
appended := reflect.AppendSlice(resultVal, pageAddr.Elem()) appended := reflect.AppendSlice(resultVal, pageAddr.Elem())
......
package hades
type AssocMode int
const (
AssocModeAppend AssocMode = iota
AssocModeReplace
)
type assocField struct {
name string
search *SearchParams
mode AssocMode
children []AssocField
}
type saveParams struct {
assocs []AssocField
omitRoot bool
}
type preloadParams struct {
assocs []AssocField
}
type SaveParam interface {
ApplyToSaveParams(sp *saveParams)
}
type PreloadParam interface {
ApplyToPreloadParams(pp *preloadParams)
}
type AssocField interface {
SaveParam
PreloadParam
Name() string
Mode() AssocMode
Search() *SearchParams
Children() []AssocField
}
// -------------
// OmitRoot tells save to not save the record passed,
// but only associations
func OmitRoot() SaveParam {
return &omitRoot{}
}
type omitRoot struct{}
func (o *omitRoot) ApplyToSaveParams(sp *saveParams) {
sp.omitRoot = true
}
// Assoc tells save to save the specified association,
// but not to remove any existing associated records, even if
// they're not listed anymore
func Assoc(fieldName string, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeAppend,
children: children,
}
}
// AssocReplace tells save to save the specified assocation,
// and to remove any associated records that are no longer listed
func AssocReplace(fieldName string, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeReplace,
children: children,
}
}
func AssocWithSearch(fieldName string, search *SearchParams, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeAppend,
search: search,
children: children,
}
}
func (f *assocField) ApplyToSaveParams(sp *saveParams) {
sp.assocs = append(sp.assocs, f)
}
func (f *assocField) ApplyToPreloadParams(pp *preloadParams) {
pp.assocs = append(pp.assocs, f)
}
func (f *assocField) Name() string {
return f.name
}
func (f *assocField) Mode() AssocMode {
return f.mode
}
func (f *assocField) Children() []AssocField {
return f.children
}
func (f *assocField) Search() *SearchParams {
return f.search
}
...@@ -3,73 +3,19 @@ package hades ...@@ -3,73 +3,19 @@ package hades
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"crawshaw.io/sqlite" "crawshaw.io/sqlite"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type PreloadParams struct { func (c *Context) Preload(conn *sqlite.Conn, rec interface{}, opts ...PreloadParam) error {
Record interface{} params := &preloadParams{}
for _, o := range opts {
// Fields to preload, for example []string{"CollectionGames", "CollectionGames.Game"} o.ApplyToPreloadParams(params)
Fields []PreloadField
}
type PreloadField struct {
Name string
Search *SearchParams
}
type Node struct {
Name string
Search *SearchParams
Field PreloadField
Children map[string]*Node
}
func NewNode(name string) *Node {
return &Node{
Name: name,
Children: make(map[string]*Node),
} }
}
func (n *Node) String() string { if len(params.assocs) == 0 {
var res []string return errors.Errorf("Cannot preload 0 assocs")
res = append(res, fmt.Sprintf("- %s%s", n.Name, n.Search))
for _, c := range n.Children {
for _, cl := range strings.Split(c.String(), "\n") {
res = append(res, " "+cl)
}
}
return strings.Join(res, "\n")
}
func (n *Node) Add(pf PreloadField) {
tokens := strings.Split(pf.Name, ".")
name := tokens[0]
c, ok := n.Children[name]
if !ok {
c = NewNode(name)
n.Children[name] = c
}
if len(tokens) > 1 {
pfc := pf
pfc.Name = strings.Join(tokens[1:], ".")
c.Add(pfc)
} else {
c.Field = pf
c.Search = pf.Search
}
}
func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
rec := params.Record
if len(params.Fields) == 0 {
return errors.New("Preload expects a non-empty list in Fields")
} }
val := reflect.ValueOf(rec) val := reflect.ValueOf(rec)
...@@ -81,43 +27,31 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error { ...@@ -81,43 +27,31 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
valtyp = valtyp.Elem() valtyp = valtyp.Elem()
} }
if valtyp.Kind() != reflect.Ptr { if valtyp.Kind() != reflect.Ptr {
return fmt.Errorf("Preload expects a []*Model or *Model, but it was passed a %v instead", val.Type()) return errors.Errorf("Preload expects a []*Model or *Model, but it was passed a %v instead", val.Type())
} }
riMap := make(RecordInfoMap) riMap := make(RecordInfoMap)
rootName := fmt.Sprintf("%v", valtyp) rootField := &assocField{
typeTree, err := c.WalkType(riMap, rootName, valtyp, make(VisitMap), nil) name: fmt.Sprintf("%v", valtyp),
if err != nil { mode: AssocModeAppend,
return errors.Wrap(err, "waking type tree") children: params.assocs,
} }
rootInfo, err := c.WalkType(riMap, rootField, valtyp)
valTree := NewNode(rootName) if err != nil {
for _, field := range params.Fields { return errors.WithMessage(err, "waking type tree")
valTree.Add(field)
} }
var walk func(p reflect.Value, pri *RecordInfo, pvt *Node) error var walk func(p reflect.Value, pri *RecordInfo) error
walk = func(p reflect.Value, pri *RecordInfo, pvt *Node) error { walk = func(p reflect.Value, pri *RecordInfo) error {
for _, cvt := range pvt.Children { ptyp := p.Type()
var cri *RecordInfo if ptyp.Kind() == reflect.Slice {
for _, c := range pri.Children { ptyp = ptyp.Elem()
if c.Name == cvt.Name { }
cri = c if ptyp.Kind() != reflect.Ptr {
break return errors.Errorf("walk expects a []*Model or *Model, but it was passed a %v instead", p.Type())
} }
}
if cri == nil {
return fmt.Errorf("Relation not found: %s.%s", pri.Name, cvt.Name)
}
ptyp := p.Type()
if ptyp.Kind() == reflect.Slice {
ptyp = ptyp.Elem()
}
if ptyp.Kind() != reflect.Ptr {
return fmt.Errorf("walk expects a []*Model or *Model, but it was passed a %v instead", p.Type())
}
for _, cri := range pri.Children {
freshAddr := reflect.New(reflect.SliceOf(cri.Type)) freshAddr := reflect.New(reflect.SliceOf(cri.Type))
var ps reflect.Value var ps reflect.Value
...@@ -136,9 +70,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error { ...@@ -136,9 +70,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
} }
var err error var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search) freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil { if err != nil {
return errors.Wrap(err, "fetching has_many records (paginated)") return errors.WithMessage(err, "fetching has_many records (paginated)")
} }
pByFK := make(map[interface{}]reflect.Value) pByFK := make(map[interface{}]reflect.Value)
...@@ -149,7 +83,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error { ...@@ -149,7 +83,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
// reset slices so if preload is called more than once, // reset slices so if preload is called more than once,
// it doesn't keep appending // it doesn't keep appending
field := rec.Elem().FieldByName(cvt.Name) field := rec.Elem().FieldByName(cri.Name())
field.Set(reflect.New(field.Type()).Elem()) field.Set(reflect.New(field.Type()).Elem())
} }
...@@ -157,7 +91,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error { ...@@ -157,7 +91,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
for i := 0; i < fresh.Len(); i++ { for i := 0; i < fresh.Len(); i++ {
fk := fresh.Index(i).Elem().FieldByName(cri.Relationship.ForeignFieldNames[0]).Interface() fk := fresh.Index(i).Elem().FieldByName(cri.Relationship.ForeignFieldNames[0]).Interface()
if p, ok := pByFK[fk]; ok { if p, ok := pByFK[fk]; ok {
dest := p.Elem().FieldByName(cvt.Name) dest := p.Elem().FieldByName(cri.Name())
dest.Set(reflect.Append(dest, fresh.Index(i))) dest.Set(reflect.Append(dest, fresh.Index(i)))
} }
} }
...@@ -169,9 +103,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error { ...@@ -169,9 +103,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
} }
var err error var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search) freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil { if err != nil {
return errors.Wrap(err, "fetching has_one records (paginated)") return errors.WithMessage(err, "fetching has_one records (paginated)")
} }
fresh := freshAddr.Elem() fresh := freshAddr.Elem()
...@@ -186,7 +120,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error { ...@@ -186,7 +120,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
prec := ps.Index(i) prec := ps.Index(i)
fk := prec.Elem().FieldByName(cri.Relationship.AssociationForeignFieldNames[0]).Interface() fk := prec.Elem().FieldByName(cri.Relationship.AssociationForeignFieldNames[0]).Interface()
if crec, ok := freshByFK[fk]; ok { if crec, ok := freshByFK[fk]; ok {
prec.Elem().FieldByName(cvt.Name).Set(crec) prec.Elem().FieldByName(cri.Name()).Set(crec)
} }
} }
case "belongs_to": case "belongs_to":
...@@ -197,9 +131,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error { ...@@ -197,9 +131,9 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
} }
var err error var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.AssociationForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search) freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.AssociationForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil { if err != nil {
return errors.Wrap(err, "fetching belongs_to records (paginated)") return errors.WithMessage(err, "fetching belongs_to records (paginated)")
} }
fresh := freshAddr.Elem() fresh := freshAddr.Elem()
...@@ -214,23 +148,23 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error { ...@@ -214,23 +148,23 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
prec := ps.Index(i) prec := ps.Index(i)
fk := prec.Elem().FieldByName(cri.Relationship.ForeignFieldNames[0]).Interface() fk := prec.Elem().FieldByName(cri.Relationship.ForeignFieldNames[0]).Interface()
if crec, ok := freshByFK[fk]; ok { if crec, ok := freshByFK[fk]; ok {
prec.Elem().FieldByName(cvt.Name).Set(crec) prec.Elem().FieldByName(cri.Name()).Set(crec)
} }
} }
default: default:
return fmt.Errorf("Preload doesn't know how to handle %s relationships", cri.Relationship.Kind) return errors.Errorf("Preload doesn't know how to handle %s relationships", cri.Relationship.Kind)
} }
fresh := freshAddr.Elem() fresh := freshAddr.Elem()
err = walk(fresh, cri, cvt) err = walk(fresh, cri)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
} }
return nil return nil
} }
err = walk(val, typeTree, valTree) err = walk(val, rootInfo)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
......
...@@ -24,20 +24,10 @@ func Test_PreloadEdgeCases(t *testing.T) { ...@@ -24,20 +24,10 @@ func Test_PreloadEdgeCases(t *testing.T) {
withContext(t, models, func(conn *sqlite.Conn, c *hades.Context) { withContext(t, models, func(conn *sqlite.Conn, c *hades.Context) {
// non-existent Bar // non-existent Bar
f := &Foo{ID: 1, BarID: 999} f := &Foo{ID: 1, BarID: 999}
wtest.Must(t, c.Preload(conn, &hades.PreloadParams{ wtest.Must(t, c.Preload(conn, f, hades.Assoc("Bar")))
Record: f,
Fields: []hades.PreloadField{
{Name: "Bar"},
},
}))
// empty slice // empty slice
var foos []*Foo var foos []*Foo
wtest.Must(t, c.Preload(conn, &hades.PreloadParams{ wtest.Must(t, c.Preload(conn, foos, hades.Assoc("Bar")))
Record: foos,
Fields: []hades.PreloadField{
{Name: "Bar"},
},
}))
}) })
} }
package hades package hades
import ( import (
"fmt"
"reflect" "reflect"
"strings"
"github.com/go-xorm/builder" "github.com/go-xorm/builder"
"github.com/itchio/hades/sqliteutil2" "github.com/itchio/hades/sqliteutil2"
...@@ -14,35 +14,16 @@ import ( ...@@ -14,35 +14,16 @@ import (
type AllEntities map[reflect.Type]EntityMap type AllEntities map[reflect.Type]EntityMap
type EntityMap []interface{} type EntityMap []interface{}
type SaveParams struct { func (c *Context) Save(conn *sqlite.Conn, rec interface{}, opts ...SaveParam) (err error) {
// Record to save
Record interface{}
// Fields to save instead of the top-level record
Assocs []string
// For has_many and many_to_many, never delete rows for these models
DontCull []interface{}
}
func (c *Context) SaveOne(conn *sqlite.Conn, record interface{}) (err error) {
return c.SaveNoTransaction(conn, &SaveParams{
Record: record,
})
}
func (c *Context) Save(conn *sqlite.Conn, params *SaveParams) (err error) {
defer sqliteutil2.Save(conn)(&err) defer sqliteutil2.Save(conn)(&err)
return c.SaveNoTransaction(conn, rec, opts...)
return c.SaveNoTransaction(conn, params)
} }
func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error { func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...SaveParam) error {
if params == nil { var params saveParams
return errors.New("Save: params cannot be nil") for _, o := range opts {
o.ApplyToSaveParams(&params)
} }
rec := params.Record
assocs := params.Assocs
val := reflect.ValueOf(rec) val := reflect.ValueOf(rec)
valtyp := val.Type() valtyp := val.Type()
...@@ -50,13 +31,18 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -50,13 +31,18 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
valtyp = valtyp.Elem() valtyp = valtyp.Elem()
} }
if valtyp.Kind() != reflect.Ptr { if valtyp.Kind() != reflect.Ptr {
return fmt.Errorf("Save expects a []*Model or a *Model, but it was passed a %v instead", val.Type()) return errors.Errorf("Save expects a []*Model or a *Model, but it was passed a %v instead", val.Type())
} }
riMap := make(RecordInfoMap) riMap := make(RecordInfoMap)
tree, err := c.WalkType(riMap, "<root>", valtyp, make(VisitMap), assocs) rootField := &assocField{
name: "<root>",
mode: AssocModeAppend,
children: params.assocs,
}
rootRecordInfo, err := c.WalkType(riMap, rootField, valtyp)
if err != nil { if err != nil {
return errors.Wrap(err, "walking records to be saved") return errors.WithMessage(err, "walking records to be saved")
} }
entities := make(AllEntities) entities := make(AllEntities)
...@@ -75,7 +61,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -75,7 +61,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
switch vri.Relationship.Kind { switch vri.Relationship.Kind {
case "has_many", "has_one": case "has_many", "has_one":
if len(pri.ModelStruct.PrimaryFields) != 1 { if len(pri.ModelStruct.PrimaryFields) != 1 {
return fmt.Errorf("Since %v %s %v, we expected one primary key in %v, but found %d", return errors.Errorf("Since %v %s %v, we expected one primary key in %v, but found %d",
p.Type(), p.Type(),
vri.Relationship.Kind, vri.Relationship.Kind,
v.Type(), v.Type(),
...@@ -85,7 +71,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -85,7 +71,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
} }
pkField := p.Elem().FieldByName(pri.ModelStruct.PrimaryFields[0].Name) pkField := p.Elem().FieldByName(pri.ModelStruct.PrimaryFields[0].Name)
if len(vri.Relationship.ForeignFieldNames) != 1 { if len(vri.Relationship.ForeignFieldNames) != 1 {
return fmt.Errorf("Since %v %s %v, we expected one foreign field in %v, but found %d", return errors.Errorf("Since %v %s %v, we expected one foreign field in %v, but found %d",
p.Type(), p.Type(),
vri.Relationship.Kind, vri.Relationship.Kind,
v.Type(), v.Type(),
...@@ -97,7 +83,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -97,7 +83,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
fkField.Set(pkField) fkField.Set(pkField)
case "belongs_to": case "belongs_to":
if len(vri.ModelStruct.PrimaryFields) != 1 { if len(vri.ModelStruct.PrimaryFields) != 1 {
return fmt.Errorf("Since %v %s %v, we expected one primary key in %v, but found %d", return errors.Errorf("Since %v %s %v, we expected one primary key in %v, but found %d",
p.Type(), p.Type(),
vri.Relationship.Kind, vri.Relationship.Kind,
v.Type(), v.Type(),
...@@ -108,7 +94,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -108,7 +94,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
pkField := v.Elem().FieldByName(vri.ModelStruct.PrimaryFields[0].Name) pkField := v.Elem().FieldByName(vri.ModelStruct.PrimaryFields[0].Name)
if len(vri.Relationship.ForeignFieldNames) != 1 { if len(vri.Relationship.ForeignFieldNames) != 1 {
return fmt.Errorf("Since %v %s %v, we expected one foreign field in %v, but found %d", return errors.Errorf("Since %v %s %v, we expected one foreign field in %v, but found %d",
p.Type(), p.Type(),
vri.Relationship.Kind, vri.Relationship.Kind,
v.Type(), v.Type(),
...@@ -126,21 +112,21 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -126,21 +112,21 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
numVisited++ numVisited++
err := addEntity(v) err := addEntity(v)
if err != nil { if err != nil {
return errors.Wrap(err, "adding entity") return errors.WithMessage(err, "adding entity")
} }
} }
if v.Kind() != reflect.Ptr { if v.Kind() != reflect.Ptr {
return fmt.Errorf("expected a pointer, but got with %v", v) return errors.Errorf("expected a pointer, but got with %v", v)
} }
vs := v.Elem() vs := v.Elem()
if vs.Kind() != reflect.Struct { if vs.Kind() != reflect.Struct {
return fmt.Errorf("expected a struct, but got with %v", v) return errors.Errorf("expected a struct, but got with %v", v)
} }
for _, childRi := range vri.Children { for _, childRi := range vri.Children {
child := vs.FieldByName(childRi.Name) child := vs.FieldByName(childRi.Name())
if !child.IsValid() { if !child.IsValid() {
continue continue
} }
...@@ -153,7 +139,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -153,7 +139,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
persistChildren := true persistChildren := true
err := walk(v, vri, child, childRi, persistChildren) err := walk(v, vri, child, childRi, persistChildren)
if err != nil { if err != nil {
return errors.Wrap(err, "walking child entities to be saved") return errors.WithMessage(err, "walking child entities to be saved")
} }
} }
return nil return nil
...@@ -162,15 +148,11 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -162,15 +148,11 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
walk = func(p reflect.Value, pri *RecordInfo, v reflect.Value, vri *RecordInfo, persist bool) error { walk = func(p reflect.Value, pri *RecordInfo, v reflect.Value, vri *RecordInfo, persist bool) error {
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
cull := false cull := false
if vri.Relationship != nil { if vri.Relationship != nil {
switch vri.Relationship.Kind { switch vri.Relationship.Kind {
case "has_many": case "has_many":
cull = true if vri.Field.Mode() == AssocModeReplace {
for _, dc := range params.DontCull { cull = true
if reflect.TypeOf(dc).Elem() == vri.ModelStruct.ModelType {
cull = false
}
} }
case "many_to_many": case "many_to_many":
// culling is done later, but let's record the ManyToMany now // culling is done later, but let's record the ManyToMany now
...@@ -181,7 +163,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -181,7 +163,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
err := visit(p, pri, v.Index(i), vri, persist) err := visit(p, pri, v.Index(i), vri, persist)
if err != nil { if err != nil {
return errors.Wrap(err, "walking slice of children") return errors.WithMessage(err, "walking slice of children")
} }
} }
...@@ -196,7 +178,12 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -196,7 +178,12 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
parentPK := parentPF.Field parentPK := parentPF.Field
if len(vri.ModelStruct.PrimaryFields) != 1 { if len(vri.ModelStruct.PrimaryFields) != 1 {
return errors.Errorf("Since %v has_many %v", pri.Name, vri.Name) var pfNames []string
for _, pf := range vri.ModelStruct.PrimaryFields {
pfNames = append(pfNames, pf.Name)
}
return errors.Errorf("Since %v has_many %v, expected %v to have one primary key. Instead, it has primary fields: %s",
pri.Name, vri.Name, strings.Join(pfNames, ", "))
} }
valuePF := c.NewScope(v.Interface()).PrimaryField() valuePF := c.NewScope(v.Interface()).PrimaryField()
if valuePF == nil { if valuePF == nil {
...@@ -247,30 +234,30 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error ...@@ -247,30 +234,30 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
} else { } else {
err := visit(p, pri, v, vri, persist) err := visit(p, pri, v, vri, persist)
if err != nil { if err != nil {
return errors.Wrap(err, "walking single child") return errors.WithMessage(err, "walking single child")
} }
} }
return nil return nil
} }
persistRoot := assocs == nil err = walk(reflect.Zero(reflect.TypeOf(0)), nil, val, rootRecordInfo, !params.omitRoot)
err = walk(reflect.Zero(reflect.TypeOf(0)), nil, val, tree, persistRoot)
if err != nil { if err != nil {
return errors.Wrap(err, "walking all records to be persisted") return errors.WithMessage(err, "walking all records to be persisted")
} }
for _, m := range entities { for typ, m := range entities {
err := c.saveRows(conn, params, m) ri := riMap[typ]
err := c.saveRows(conn, ri.Field.Mode(), m)
if err != nil { if err != nil {
return errors.Wrap(err, "saving rows") return errors.WithMessage(err, "saving rows")
} }
} }
for _, ri := range riMap { for _, ri := range riMap {
if ri.ManyToMany != nil { if ri.ManyToMany != nil {
err := c.saveJoins(params, conn, ri.ManyToMany) err := c.saveJoins(conn, ri.Field.Mode(), ri.ManyToMany)
if err != nil { if err != nil {
return errors.Wrap(err, "saving joins") return errors.WithMessage(err, "saving joins")
} }
} }
} }
......
...@@ -8,15 +8,7 @@ import ( ...@@ -8,15 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMany) error { func (c *Context) saveJoins(conn *sqlite.Conn, mode AssocMode, mtm *ManyToMany) error {
cull := true
for _, dc := range params.DontCull {
if mtm.JoinTable == ToDBName(c.NewScope(dc).TableName()) {
cull = false
break
}
}
joinType := reflect.PtrTo(mtm.Scope.GetModelStruct().ModelType) joinType := reflect.PtrTo(mtm.Scope.GetModelStruct().ModelType)
getDestinKey := func(v reflect.Value) interface{} { getDestinKey := func(v reflect.Value) interface{} {
...@@ -28,7 +20,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa ...@@ -28,7 +20,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
err := c.Select(conn, cacheAddr.Interface(), builder.Eq{mtm.SourceDBName: sourceKey}, nil) err := c.Select(conn, cacheAddr.Interface(), builder.Eq{mtm.SourceDBName: sourceKey}, nil)
if err != nil { if err != nil {
return errors.Wrap(err, "fetching cached records to compare later") return errors.WithMessage(err, "fetching cached records to compare later")
} }
cache := cacheAddr.Elem() cache := cacheAddr.Elem()
...@@ -60,7 +52,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa ...@@ -60,7 +52,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
cf, err := DiffRecord(ifrec, icrec, mtm.Scope) cf, err := DiffRecord(ifrec, icrec, mtm.Scope)
if err != nil { if err != nil {
return errors.Wrap(err, "diffing database records") return errors.WithMessage(err, "diffing database records")
} }
if cf != nil { if cf != nil {
...@@ -78,14 +70,10 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa ...@@ -78,14 +70,10 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
} }
} }
if !cull { if mode == AssocModeReplace && len(deletes) > 0 {
// Not deleting extra join records, as requested err := c.deletePagedByPK(conn, mtm.JoinTable, mtm.DestinDBName, deletes, builder.Eq{mtm.SourceDBName: sourceKey})
} else { if err != nil {
if len(deletes) > 0 { return errors.WithMessage(err, "deleting extraneous relations")
err := c.deletePagedByPK(conn, mtm.JoinTable, mtm.DestinDBName, deletes, builder.Eq{mtm.SourceDBName: sourceKey})
if err != nil {
return errors.Wrap(err, "deleting extraneous relations")
}
} }
} }
...@@ -95,7 +83,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa ...@@ -95,7 +83,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
if rec.IsValid() { if rec.IsValid() {
err := c.Insert(conn, mtm.Scope, rec) err := c.Insert(conn, mtm.Scope, rec)
if err != nil { if err != nil {
return errors.Wrap(err, "creating new relation records") return errors.WithMessage(err, "creating new relation records")
} }
} else { } else {
// if not passed an explicit record, make it ourselves // if not passed an explicit record, make it ourselves
...@@ -117,7 +105,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa ...@@ -117,7 +105,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
query := builder.Update(cf.ToEq()).Into(mtm.Scope.TableName()).Where(builder.Eq{mtm.SourceDBName: sourceKey, mtm.DestinDBName: destinKey}) query := builder.Update(cf.ToEq()).Into(mtm.Scope.TableName()).Where(builder.Eq{mtm.SourceDBName: sourceKey, mtm.DestinDBName: destinKey})
err := c.Exec(conn, query, nil) err := c.Exec(conn, query, nil)
if err != nil { if err != nil {
return errors.Wrap(err, "updating related records") return errors.WithMessage(err, "updating related records")