Commit a03f88d2 authored by Amos Wenger's avatar Amos Wenger

nostructs style

parent 33fb7d0c
Pipeline #11010 failed with stage
in 18 seconds
......@@ -46,7 +46,7 @@ func Test_AutoMigrate(t *testing.T) {
assert.False(t, pti[1].PrimaryKey)
assert.False(t, pti[1].NotNull)
ordie(c.SaveOne(conn, &User{ID: 123, FirstName: "Joanna"}))
ordie(c.Save(conn, &User{ID: 123, FirstName: "Joanna"}))
u := &User{}
foundUser, err := c.SelectOne(conn, u, builder.Eq{"id": 123})
ordie(err)
......@@ -193,7 +193,7 @@ func Test_AutoMigrateAllValidTypes(t *testing.T) {
FirstName: "Jeremy",
HeartRate: 3.14,
}
ordie(c.SaveOne(conn, h1))
ordie(c.Save(conn, h1))
h2 := &Humanoid{}
found, err := c.SelectOne(conn, h2, builder.Eq{"id": 12})
......
......@@ -36,22 +36,18 @@ func Test_BelongsTo(t *testing.T) {
Desc: "Consumer-grade flamethrowers",
}
t.Log("Saving one fate")
wtest.Must(t, c.SaveOne(conn, someFate))
wtest.Must(t, c.Save(conn, someFate))
lea := &Human{
ID: 3,
FateID: someFate.ID,
}
t.Log("Saving one human")
wtest.Must(t, c.SaveOne(conn, lea))
wtest.Must(t, c.Save(conn, lea))
t.Log("Preloading lea")
c.Preload(conn, &hades.PreloadParams{
Record: lea,
Fields: []hades.PreloadField{
{Name: "Fate"},
},
})
c.Preload(conn, lea, hades.Assoc("Fate"))
assert.NotNil(t, lea.Fate)
assert.EqualValues(t, someFate.Desc, lea.Fate.Desc)
})
......@@ -64,10 +60,7 @@ func Test_BelongsTo(t *testing.T) {
Desc: "Book authorship",
},
}
c.Save(conn, &hades.SaveParams{
Record: lea,
Assocs: []string{"Fate"},
})
wtest.Must(t, c.Save(conn, lea, hades.Assoc("Fate")))
fate := &Fate{}
found, err := c.SelectOne(conn, fate, builder.Eq{"id": 421})
......@@ -81,19 +74,19 @@ func Test_BelongsTo(t *testing.T) {
ID: 3,
Desc: "Space rodeo",
}
wtest.Must(t, c.SaveOne(conn, fate))
wtest.Must(t, c.Save(conn, fate))
human := &Human{
ID: 6,
FateID: 3,
}
wtest.Must(t, c.SaveOne(conn, human))
wtest.Must(t, c.Save(conn, human))
joke := &Joke{
ID: "neuf",
HumanID: 6,
}
wtest.Must(t, c.SaveOne(conn, joke))
wtest.Must(t, c.Save(conn, joke))
c.Preload(conn, &hades.PreloadParams{
Record: joke,
......
......@@ -37,7 +37,7 @@ func Test_Delete(t *testing.T) {
var count int64
var err error
wtest.Must(t, c.SaveOne(conn, stories))
wtest.Must(t, c.Save(conn, stories))
count, err = c.Count(conn, &Story{}, builder.NewCond())
wtest.Must(t, err)
......
......@@ -43,12 +43,12 @@ func Test_HasMany(t *testing.T) {
{ID: 11, Label: "Ability to not repeat oneself"},
},
}
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: p1}))
wtest.Must(t, c.Save(conn, p1, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 1)
assertCount(&Quality{}, 3)
p1.Qualities[2].Label = "Inspiration again"
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: p1}))
wtest.Must(t, c.Save(conn, p1, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 1)
assertCount(&Quality{}, 3)
{
......@@ -68,20 +68,15 @@ func Test_HasMany(t *testing.T) {
}
programmers := []*Programmer{p1, p2}
wtest.Must(t, c.Save(conn, &hades.SaveParams{Record: programmers}))
wtest.Must(t, c.Save(conn, programmers, hades.Assoc("Qualities")))
assertCount(&Programmer{}, 2)
assertCount(&Quality{}, 5)
p1bis := &Programmer{ID: 3}
pp := &hades.PreloadParams{
Record: p1bis,
Fields: []hades.PreloadField{
{Name: "Qualities"},
},
}
wtest.Must(t, c.Preload(conn, pp))
wtest.Must(t, c.Preload(conn, p1bis, hades.Field("Qualities")))
assert.EqualValues(t, 3, len(p1bis.Qualities), "preload has_many")
wtest.Must(t, c.Preload(conn, pp))
wtest.Must(t, c.Preload(conn, p1bis, hades.Field("Qualities")))
assert.EqualValues(t, 3, len(p1bis.Qualities), "preload replaces, doesn't append")
pp.Fields[0] = hades.PreloadField{
......
......@@ -237,7 +237,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p))
pieceSelect := 1
pieceInsert := 1
......@@ -271,7 +271,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p))
pieceSelect := 1
......@@ -297,7 +297,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p))
pieceSelect := 1
......@@ -328,7 +328,7 @@ func Test_ManyToManyThorough(t *testing.T) {
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p))
pieceSelect := 1
......@@ -358,14 +358,14 @@ func Test_ManyToManyThorough(t *testing.T) {
})
}
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1+1200)
assertCount(&PieceAuthor{}, len(p.Authors))
p.Authors = nil
ordie(c.SaveOne(conn, p))
ordie(c.Save(conn, p))
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1+1200)
......
......@@ -52,7 +52,7 @@ func Test_Null(t *testing.T) {
ID: 123,
}
ordie(c.SaveOne(conn, d))
ordie(c.Save(conn, d))
{
dd := &Download{}
found, err := c.SelectOne(conn, dd, builder.Eq{"id": 123})
......@@ -74,7 +74,7 @@ func Test_Null(t *testing.T) {
finishedAt := time.Now()
d.FinishedAt = &finishedAt
ordie(c.SaveOne(conn, d))
ordie(c.Save(conn, d))
{
dd := &Download{}
......@@ -89,7 +89,7 @@ func Test_Null(t *testing.T) {
}
d.ErrorMessage = nil
ordie(c.SaveOne(conn, d))
ordie(c.Save(conn, d))
{
dd := &Download{}
......
package hades
type AssocMode int
const (
AssocModeAppend AssocMode = iota
AssocModeReplace
)
type assocField struct {
name string
search *SearchParams
mode AssocMode
children []AssocField
}
type saveParams struct {
assocs []AssocField
omitRoot bool
}
type preloadParams struct {
assocs []AssocField
}
type SaveParam interface {
ApplyToSaveParams(sp *saveParams)
}
type PreloadParam interface {
ApplyToPreloadParams(pp *preloadParams)
}
type AssocField interface {
SaveParam
PreloadParam
Name() string
Mode() AssocMode
Search() *SearchParams
Children() []AssocField
}
// -------------
// OmitRoot tells save to not save the record passed,
// but only associations
func OmitRoot() SaveParam {
return &omitRoot{}
}
type omitRoot struct{}
func (o *omitRoot) ApplyToSaveParams(sp *saveParams) {
sp.omitRoot = true
}
// Assoc tells save to save the specified association,
// but not to remove any existing associated records, even if
// they're not listed anymore
func Assoc(fieldName string, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeAppend,
children: children,
}
}
// AssocReplace tells save to save the specified assocation,
// and to remove any associated records that are no longer listed
func AssocReplace(fieldName string, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeReplace,
children: children,
}
}
func AssocWithSearch(fieldName string, search *SearchParams, children ...AssocField) AssocField {
return &assocField{
name: fieldName,
mode: AssocModeAppend,
search: search,
children: children,
}
}
func (f *assocField) ApplyToSaveParams(sp *saveParams) {
sp.assocs = append(sp.assocs, f)
}
func (f *assocField) ApplyToPreloadParams(pp *preloadParams) {
pp.assocs = append(pp.assocs, f)
}
func (f *assocField) Name() string {
return f.name
}
func (f *assocField) Mode() AssocMode {
return f.mode
}
func (f *assocField) Children() []AssocField {
return f.children
}
func (f *assocField) Search() *SearchParams {
return f.search
}
......@@ -3,73 +3,15 @@ package hades
import (
"fmt"
"reflect"
"strings"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
type PreloadParams struct {
Record interface{}
// Fields to preload, for example []string{"CollectionGames", "CollectionGames.Game"}
Fields []PreloadField
}
type PreloadField struct {
Name string
Search *SearchParams
}
type Node struct {
Name string
Search *SearchParams
Field PreloadField
Children map[string]*Node
}
func NewNode(name string) *Node {
return &Node{
Name: name,
Children: make(map[string]*Node),
}
}
func (n *Node) String() string {
var res []string
res = append(res, fmt.Sprintf("- %s%s", n.Name, n.Search))
for _, c := range n.Children {
for _, cl := range strings.Split(c.String(), "\n") {
res = append(res, " "+cl)
}
}
return strings.Join(res, "\n")
}
func (n *Node) Add(pf PreloadField) {
tokens := strings.Split(pf.Name, ".")
name := tokens[0]
c, ok := n.Children[name]
if !ok {
c = NewNode(name)
n.Children[name] = c
}
if len(tokens) > 1 {
pfc := pf
pfc.Name = strings.Join(tokens[1:], ".")
c.Add(pfc)
} else {
c.Field = pf
c.Search = pf.Search
}
}
func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
rec := params.Record
if len(params.Fields) == 0 {
return errors.New("Preload expects a non-empty list in Fields")
func (c *Context) Preload(conn *sqlite.Conn, rec interface{}, opts ...PreloadParam) error {
params := &preloadParams{}
for _, o := range opts {
o.ApplyToPreloadParams(params)
}
val := reflect.ValueOf(rec)
......@@ -85,39 +27,27 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
}
riMap := make(RecordInfoMap)
rootName := fmt.Sprintf("%v", valtyp)
typeTree, err := c.WalkType(riMap, rootName, valtyp, make(VisitMap), nil)
rootField := &assocField{
name: fmt.Sprintf("%v", valtyp),
mode: AssocModeAppend,
children: params.assocs,
}
rootInfo, err := c.WalkType(riMap, rootField, valtyp)
if err != nil {
return errors.Wrap(err, "waking type tree")
}
valTree := NewNode(rootName)
for _, field := range params.Fields {
valTree.Add(field)
}
var walk func(p reflect.Value, pri *RecordInfo, pvt *Node) error
walk = func(p reflect.Value, pri *RecordInfo, pvt *Node) error {
for _, cvt := range pvt.Children {
var cri *RecordInfo
for _, c := range pri.Children {
if c.Name == cvt.Name {
cri = c
break
}
}
if cri == nil {
return fmt.Errorf("Relation not found: %s.%s", pri.Name, cvt.Name)
}
ptyp := p.Type()
if ptyp.Kind() == reflect.Slice {
ptyp = ptyp.Elem()
}
if ptyp.Kind() != reflect.Ptr {
return fmt.Errorf("walk expects a []*Model or *Model, but it was passed a %v instead", p.Type())
}
var walk func(p reflect.Value, pri *RecordInfo) error
walk = func(p reflect.Value, pri *RecordInfo) error {
ptyp := p.Type()
if ptyp.Kind() == reflect.Slice {
ptyp = ptyp.Elem()
}
if ptyp.Kind() != reflect.Ptr {
return fmt.Errorf("walk expects a []*Model or *Model, but it was passed a %v instead", p.Type())
}
for _, cri := range pri.Children {
freshAddr := reflect.New(reflect.SliceOf(cri.Type))
var ps reflect.Value
......@@ -136,7 +66,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
}
var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search)
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil {
return errors.Wrap(err, "fetching has_many records (paginated)")
}
......@@ -149,7 +79,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
// reset slices so if preload is called more than once,
// it doesn't keep appending
field := rec.Elem().FieldByName(cvt.Name)
field := rec.Elem().FieldByName(cri.Name())
field.Set(reflect.New(field.Type()).Elem())
}
......@@ -157,7 +87,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
for i := 0; i < fresh.Len(); i++ {
fk := fresh.Index(i).Elem().FieldByName(cri.Relationship.ForeignFieldNames[0]).Interface()
if p, ok := pByFK[fk]; ok {
dest := p.Elem().FieldByName(cvt.Name)
dest := p.Elem().FieldByName(cri.Name())
dest.Set(reflect.Append(dest, fresh.Index(i)))
}
}
......@@ -169,7 +99,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
}
var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search)
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.ForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil {
return errors.Wrap(err, "fetching has_one records (paginated)")
}
......@@ -186,7 +116,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
prec := ps.Index(i)
fk := prec.Elem().FieldByName(cri.Relationship.AssociationForeignFieldNames[0]).Interface()
if crec, ok := freshByFK[fk]; ok {
prec.Elem().FieldByName(cvt.Name).Set(crec)
prec.Elem().FieldByName(cri.Name()).Set(crec)
}
}
case "belongs_to":
......@@ -197,7 +127,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
}
var err error
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.AssociationForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cvt.Search)
freshAddr, err = c.fetchPagedByPK(conn, cri.Relationship.AssociationForeignDBNames[0], keys, reflect.SliceOf(cri.Type), cri.Field.Search())
if err != nil {
return errors.Wrap(err, "fetching belongs_to records (paginated)")
}
......@@ -214,7 +144,7 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
prec := ps.Index(i)
fk := prec.Elem().FieldByName(cri.Relationship.ForeignFieldNames[0]).Interface()
if crec, ok := freshByFK[fk]; ok {
prec.Elem().FieldByName(cvt.Name).Set(crec)
prec.Elem().FieldByName(cri.Name()).Set(crec)
}
}
default:
......@@ -223,14 +153,14 @@ func (c *Context) Preload(conn *sqlite.Conn, params *PreloadParams) error {
fresh := freshAddr.Elem()
err = walk(fresh, cri, cvt)
err = walk(fresh, cri)
if err != nil {
return errors.WithStack(err)
}
}
return nil
}
err = walk(val, typeTree, valTree)
err = walk(val, rootInfo)
if err != nil {
return errors.WithStack(err)
}
......
......@@ -2,7 +2,9 @@ package hades
import (
"fmt"
"log"
"reflect"
"strings"
"github.com/go-xorm/builder"
"github.com/itchio/hades/sqliteutil2"
......@@ -14,35 +16,16 @@ import (
type AllEntities map[reflect.Type]EntityMap
type EntityMap []interface{}
type SaveParams struct {
// Record to save
Record interface{}
// Fields to save instead of the top-level record
Assocs []string
// For has_many and many_to_many, never delete rows for these models
DontCull []interface{}
}
func (c *Context) SaveOne(conn *sqlite.Conn, record interface{}) (err error) {
return c.SaveNoTransaction(conn, &SaveParams{
Record: record,
})
}
func (c *Context) Save(conn *sqlite.Conn, params *SaveParams) (err error) {
func (c *Context) Save(conn *sqlite.Conn, rec interface{}, opts ...SaveParam) (err error) {
defer sqliteutil2.Save(conn)(&err)
return c.SaveNoTransaction(conn, params)
return c.SaveNoTransaction(conn, rec, opts...)
}
func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error {
if params == nil {
return errors.New("Save: params cannot be nil")
func (c *Context) SaveNoTransaction(conn *sqlite.Conn, rec interface{}, opts ...SaveParam) error {
var params saveParams
for _, o := range opts {
o.ApplyToSaveParams(&params)
}
rec := params.Record
assocs := params.Assocs
val := reflect.ValueOf(rec)
valtyp := val.Type()
......@@ -54,7 +37,12 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
}
riMap := make(RecordInfoMap)
tree, err := c.WalkType(riMap, "<root>", valtyp, make(VisitMap), assocs)
rootField := &assocField{
name: "<root>",
mode: AssocModeAppend,
children: params.assocs,
}
rootRecordInfo, err := c.WalkType(riMap, rootField, valtyp)
if err != nil {
return errors.Wrap(err, "walking records to be saved")
}
......@@ -140,7 +128,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
}
for _, childRi := range vri.Children {
child := vs.FieldByName(childRi.Name)
child := vs.FieldByName(childRi.Name())
if !child.IsValid() {
continue
}
......@@ -161,23 +149,6 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
walk = func(p reflect.Value, pri *RecordInfo, v reflect.Value, vri *RecordInfo, persist bool) error {
if v.Kind() == reflect.Slice {
cull := false
if vri.Relationship != nil {
switch vri.Relationship.Kind {
case "has_many":
cull = true
for _, dc := range params.DontCull {
if reflect.TypeOf(dc).Elem() == vri.ModelStruct.ModelType {
cull = false
}
}
case "many_to_many":
// culling is done later, but let's record the ManyToMany now
vri.ManyToMany.Mark(p)
}
}
for i := 0; i < v.Len(); i++ {
err := visit(p, pri, v.Index(i), vri, persist)
if err != nil {
......@@ -185,7 +156,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
}
}
if cull {
if vri.Field.Mode() == AssocModeReplace {
var oldValuePKs []string
rel := vri.Relationship
......@@ -195,8 +166,17 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
}
parentPK := parentPF.Field
log.Printf("%v has_many %v:", pri.Type, vri.Type)
log.Printf("AssociationForeignDBNames: %v", vri.Relationship.AssociationForeignDBNames)
log.Printf(" ForeignDBNames: %v", vri.Relationship.ForeignDBNames)
if len(vri.ModelStruct.PrimaryFields) != 1 {
return errors.Errorf("Since %v has_many %v", pri.Name, vri.Name)
var pfNames []string
for _, pf := range vri.ModelStruct.PrimaryFields {
pfNames = append(pfNames, pf.Name)
}
return errors.Errorf("Since %v has_many %v, expected %v to have one primary key. Instead, it has primary fields: %s",
pri.Name, vri.Name, strings.Join(pfNames, ", "))
}
valuePF := c.NewScope(v.Interface()).PrimaryField()
if valuePF == nil {
......@@ -253,14 +233,14 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
return nil
}
persistRoot := assocs == nil
err = walk(reflect.Zero(reflect.TypeOf(0)), nil, val, tree, persistRoot)
err = walk(reflect.Zero(reflect.TypeOf(0)), nil, val, rootRecordInfo, !params.omitRoot)
if err != nil {
return errors.Wrap(err, "walking all records to be persisted")
}
for _, m := range entities {
err := c.saveRows(conn, params, m)
for typ, m := range entities {
ri := riMap[typ]
err := c.saveRows(conn, ri.Field.Mode(), m)
if err != nil {
return errors.Wrap(err, "saving rows")
}
......@@ -268,7 +248,7 @@ func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error
for _, ri := range riMap {
if ri.ManyToMany != nil {
err := c.saveJoins(params, conn, ri.ManyToMany)
err := c.saveJoins(conn, ri.Field.Mode(), ri.ManyToMany)
if err != nil {
return errors.Wrap(err, "saving joins")
}
......
......@@ -8,15 +8,7 @@ import (
"github.com/pkg/errors"
)
func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMany) error {
cull := true
for _, dc := range params.DontCull {
if mtm.JoinTable == ToDBName(c.NewScope(dc).TableName()) {
cull = false
break