Commit f03ab1c1 authored by Amos Wenger's avatar Amos Wenger

Buncha cleanups and more tests

parent a98c44fe
Pipeline #10422 passed with stage
in 25 seconds
......@@ -6,7 +6,7 @@
[![GoDoc](https://godoc.org/github.com/itchio/hades?status.svg)](https://godoc.org/github.com/itchio/hades)
[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/itchio/hades/blob/master/LICENSE)
hades is a persistent layer based on sqlite.
hades is a persistence layer based on SQLite.
## License
......
......@@ -4,13 +4,13 @@ import (
"fmt"
"reflect"
"strings"
"time"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/pkg/errors"
)
// TODO: if table already exists, just add fields
func (c *Context) AutoMigrate(conn *sqlite.Conn) error {
for _, m := range c.ScopeMap.byDBName {
err := c.syncTable(conn, m.GetModelStruct())
......@@ -44,7 +44,25 @@ func (c *Context) syncTable(conn *sqlite.Conn, ms *ModelStruct) (err error) {
oldColumns[ptir.Name] = ptir
}
// TODO: don't do anything if already good
numOldCols := len(oldColumns)
numNewCols := 0
isMissingCols := false
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
numNewCols++
if _, ok := oldColumns[sf.DBName]; !ok {
isMissingCols = true
break
}
}
if !isMissingCols && numOldCols == numNewCols {
// all done
return nil
}
tempName := fmt.Sprintf("__hades_migrate__%s__", tableName)
err = c.ExecRaw(conn, fmt.Sprintf("CREATE TABLE %s AS SELECT * FROM %s", tempName, tableName), nil)
......@@ -104,7 +122,7 @@ func (c *Context) createTable(conn *sqlite.Conn, ms *ModelStruct) error {
var columns []string
var pks []string
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
if !sf.IsNormal {
continue
}
......@@ -116,6 +134,12 @@ func (c *Context) createTable(conn *sqlite.Conn, ms *ModelStruct) error {
sqliteType = "REAL"
case reflect.String:
sqliteType = "TEXT"
case reflect.Struct:
if sf.Struct.Type == reflect.TypeOf(time.Time{}) {
sqliteType = "DATETIME"
break
}
fallthrough
default:
return errors.Errorf("Unsupported model field type: %v (in model %v)", sf.Struct.Type, ms.ModelType)
}
......
......@@ -3,6 +3,7 @@ package hades_test
import (
"context"
"testing"
"time"
"crawshaw.io/sqlite"
"github.com/go-xorm/builder"
......@@ -96,6 +97,87 @@ func Test_AutoMigrate(t *testing.T) {
}
}
func Test_AutoMigrateNoPK(t *testing.T) {
dbpool, err := sqlite.Open("file:memory:?mode=memory", 0, 10)
ordie(err)
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
type Humanoid struct {
Name string
}
models := []interface{}{&Humanoid{}}
c, err := hades.NewContext(makeConsumer(t), models...)
ordie(err)
c.Log = true
err = c.AutoMigrate(conn)
assert.Error(t, err)
}
func Test_AutoMigrateAllValidTypes(t *testing.T) {
dbpool, err := sqlite.Open("file:memory:?mode=memory", 0, 10)
ordie(err)
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
type Humanoid struct {
ID int64
FirstName string
Alive bool
HeartRate float64
BornAt time.Time
Whatever struct {
Ohey string
ThisIsValid int64
} `hades:"-"`
}
models := []interface{}{&Humanoid{}}
c, err := hades.NewContext(makeConsumer(t), models...)
ordie(err)
c.Log = true
ordie(c.AutoMigrate(conn))
pti, err := c.PragmaTableInfo(conn, "humanoids")
ordie(err)
assert.EqualValues(t, 5, len(pti))
assert.EqualValues(t, "id", pti[0].Name)
assert.EqualValues(t, "INTEGER", pti[0].Type)
assert.True(t, pti[0].PrimaryKey)
assert.True(t, pti[0].NotNull)
assert.EqualValues(t, "first_name", pti[1].Name)
assert.EqualValues(t, "TEXT", pti[1].Type)
assert.False(t, pti[1].PrimaryKey)
assert.False(t, pti[1].NotNull)
assert.EqualValues(t, "alive", pti[2].Name)
assert.EqualValues(t, "INTEGER", pti[2].Type)
assert.False(t, pti[2].PrimaryKey)
assert.False(t, pti[2].NotNull)
assert.EqualValues(t, "heart_rate", pti[3].Name)
assert.EqualValues(t, "REAL", pti[3].Type)
assert.False(t, pti[3].PrimaryKey)
assert.False(t, pti[3].NotNull)
assert.EqualValues(t, "born_at", pti[4].Name)
assert.EqualValues(t, "DATETIME", pti[4].Type)
assert.False(t, pti[4].PrimaryKey)
assert.False(t, pti[4].NotNull)
}
func ordie(err error) {
if err != nil {
panic(err)
......
......@@ -5,11 +5,12 @@ import (
)
type Context struct {
ScopeMap *ScopeMap
Consumer *state.Consumer
Stats Stats
Error error
Log bool
ScopeMap *ScopeMap
Consumer *state.Consumer
Stats Stats
Error error
Log bool
QueryCount int64
}
type Stats struct {
......
......@@ -20,6 +20,8 @@ func (c *Context) Exec(conn *sqlite.Conn, b *builder.Builder, resultFn ResultFn)
}
func (c *Context) ExecRaw(conn *sqlite.Conn, query string, resultFn ResultFn, args ...interface{}) error {
c.QueryCount++
var startTime time.Time
if c.Log {
startTime = time.Now()
......
......@@ -328,14 +328,14 @@ func Test_ManyToMany(t *testing.T) {
assertCount(&Word{}, 2)
assertCount(&LanguageWord{}, 4)
t.Logf("saving partial joins ('add' words to english)")
t.Logf("saving without culling ('add' words to english)")
en.Words = []*Word{
{ID: "Wreck"},
{ID: "Nervous"},
}
wtest.Must(t, c.Save(conn, &hades.SaveParams{
Record: []*Language{en},
PartialJoins: []string{"LanguageWords"},
Record: []*Language{en},
DontCull: []interface{}{&LanguageWord{}},
}))
assertCount(&Language{}, 2)
......
package hades_test
import (
"context"
"fmt"
"testing"
"github.com/go-xorm/builder"
"github.com/stretchr/testify/assert"
"crawshaw.io/sqlite"
"github.com/itchio/hades"
)
func Test_HasManyThorough(t *testing.T) {
dbpool, err := sqlite.Open("file:memory:?mode=memory", 0, 10)
ordie(err)
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
type Trait struct {
ID int64
CarID int64
Label string
}
type Car struct {
ID int64
Traits []*Trait
}
models := []interface{}{&Car{}, &Trait{}}
c, err := hades.NewContext(makeConsumer(t), models...)
ordie(err)
c.Log = true
ordie(c.AutoMigrate(conn))
// let's be terrible
car := &Car{ID: 123}
// the goal here is to go above SQLite's 999 variables limit
for i := 0; i < 1300; i++ {
car.Traits = append(car.Traits, &Trait{
ID: int64(i),
CarID: car.ID,
Label: fmt.Sprintf("car-trait-#%d", i),
})
}
traitCount, err := c.Count(conn, &Trait{}, builder.NewCond())
ordie(err)
assert.EqualValues(t, 0, traitCount, "no traits should exist before save")
t.Logf("...snip tons of INSERT...")
c.Log = false
ordie(c.Save(conn, &hades.SaveParams{
Record: car,
Assocs: []string{"Traits"},
}))
c.Log = true
numTraits := len(car.Traits)
traitCount, err = c.Count(conn, &Trait{}, builder.NewCond())
ordie(err)
assert.EqualValues(t, numTraits, traitCount, "all traits should exist after save")
car.Traits = nil
ordie(c.Save(conn, &hades.SaveParams{
Record: car,
Assocs: []string{"Traits"},
DontCull: []interface{}{&Trait{}},
}))
traitCount, err = c.Count(conn, &Trait{}, builder.NewCond())
ordie(err)
assert.EqualValues(t, numTraits, traitCount, "traits should still exist after partial-join save")
ordie(c.Save(conn, &hades.SaveParams{
Record: car,
Assocs: []string{"Traits"},
}))
traitCount, err = c.Count(conn, &Trait{}, builder.NewCond())
ordie(err)
assert.EqualValues(t, 0, traitCount, "no traits should exist after last save")
}
package hades_test
import (
"context"
"testing"
"github.com/go-xorm/builder"
"github.com/stretchr/testify/assert"
"crawshaw.io/sqlite"
"github.com/itchio/hades"
)
type Piece struct {
ID int64
Authors []*Author `hades:"many2many:piece_authors"`
}
type Author struct {
ID int64
Name string
Pieces []*Piece `hades:"many2many:piece_authors"`
}
type PieceAuthor struct {
AuthorID int64 `hades:"primary_key"`
PieceID int64 `hades:"primary_key"`
}
func Test_ManyToManyThorough(t *testing.T) {
dbpool, err := sqlite.Open("file:memory:?mode=memory", 0, 10)
ordie(err)
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
models := []interface{}{&Piece{}, &Author{}, &PieceAuthor{}}
c, err := hades.NewContext(makeConsumer(t), models...)
ordie(err)
c.Log = true
ordie(c.AutoMigrate(conn))
assertCount := func(model interface{}, expected int) {
t.Helper()
actual, err := c.Count(conn, model, builder.NewCond())
ordie(err)
assert.EqualValues(t, expected, actual)
}
t.Logf("Creating 1 piece with 10 authors")
p := &Piece{ID: 321}
for i := 0; i < 10; i++ {
p.Authors = append(p.Authors, &Author{
ID: int64(i + 1000),
})
}
originalAuthors := p.Authors
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
pieceSelect := 1
pieceInsert := 1
authorSelect := 1
authorInsert := len(p.Authors)
pieceAuthorSelect := 1
pieceAuthorInsert := len(p.Authors)
total := pieceSelect + pieceInsert +
authorSelect + authorInsert +
pieceAuthorSelect + pieceAuthorInsert
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(p.Authors))
assertCount(&PieceAuthor{}, len(p.Authors))
t.Logf("Disassociating 5 authors from piece")
var fewerAuthors []*Author
for i, author := range p.Authors {
if i%2 == 0 {
fewerAuthors = append(fewerAuthors, author)
}
}
p.Authors = fewerAuthors
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
pieceSelect := 1
authorSelect := 1
pieceAuthorSelect := 1
pieceAuthorDelete := 1
total := pieceSelect +
authorSelect +
pieceAuthorSelect + pieceAuthorDelete
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors))
assertCount(&PieceAuthor{}, len(p.Authors))
t.Logf("Updating 1 author")
p.Authors[2].Name = "Hansel"
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
pieceSelect := 1
authorSelect := 1
authorUpdate := 1
pieceAuthorSelect := 1
total := pieceSelect +
authorSelect + authorUpdate +
pieceAuthorSelect
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors))
assertCount(&PieceAuthor{}, len(p.Authors))
t.Logf("Updating 2 authors, adding 1, deleting 1")
p.Authors[0].Name = "Grieschka"
p.Authors[1].Name = "Peggy"
p.Authors = append(p.Authors[0:4], &Author{
ID: 2001,
Name: "Joseph",
})
{
beforeSaveQueryCount := c.QueryCount
ordie(c.SaveOne(conn, p))
pieceSelect := 1
authorSelect := 1
authorInsert := 1
authorUpdate := 2
pieceAuthorSelect := 1
pieceAuthorInsert := 1
pieceAuthorDelete := 1
total := pieceSelect +
authorSelect + authorInsert + authorUpdate +
pieceAuthorSelect + pieceAuthorInsert + pieceAuthorDelete
assert.EqualValues(t, total, c.QueryCount-beforeSaveQueryCount)
}
assertCount(&Piece{}, 1)
assertCount(&Author{}, len(originalAuthors)+1)
assertCount(&PieceAuthor{}, len(p.Authors))
}
package hades
import (
"database/sql"
"go/ast"
"reflect"
"strings"
......@@ -45,51 +44,23 @@ type ModelStruct struct {
// StructField model field's struct definition
type StructField struct {
DBName string
Name string
Names []string
IsPrimaryKey bool
IsNormal bool
IsIgnored bool
IsScanner bool
HasDefaultValue bool
Tag reflect.StructTag
TagSettings map[string]string
Struct reflect.StructField
IsForeignKey bool
Relationship *Relationship
}
func (structField *StructField) clone() *StructField {
clone := &StructField{
DBName: structField.DBName,
Name: structField.Name,
Names: structField.Names,
IsPrimaryKey: structField.IsPrimaryKey,
IsNormal: structField.IsNormal,
IsIgnored: structField.IsIgnored,
IsScanner: structField.IsScanner,
HasDefaultValue: structField.HasDefaultValue,
Tag: structField.Tag,
TagSettings: map[string]string{},
Struct: structField.Struct,
IsForeignKey: structField.IsForeignKey,
}
if structField.Relationship != nil {
relationship := *structField.Relationship
clone.Relationship = &relationship
}
for key, value := range structField.TagSettings {
clone.TagSettings[key] = value
}
return clone
DBName string
Name string
Names []string
IsPrimaryKey bool
IsNormal bool
IsIgnored bool
IsScanner bool
Tag reflect.StructTag
TagSettings map[string]string
Struct reflect.StructField
IsForeignKey bool
Relationship *Relationship
}
// Relationship described the relationship between models
type Relationship struct {
// belongs_to, has_one, has_many, many_to_many
Kind string
PolymorphicType string
PolymorphicDBName string
......@@ -156,63 +127,15 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
if _, ok := field.TagSettings["DEFAULT"]; ok {
field.HasDefaultValue = true
}
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey {
field.HasDefaultValue = true
}
indirectType := fieldStruct.Type
for indirectType.Kind() == reflect.Ptr {
indirectType = indirectType.Elem()
}
fieldValue := reflect.New(indirectType).Interface()
if _, isScanner := fieldValue.(sql.Scanner); isScanner {
// is scanner
field.IsScanner, field.IsNormal = true, true
if indirectType.Kind() == reflect.Struct {
for i := 0; i < indirectType.NumField(); i++ {
for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
}
} else if _, isTime := fieldValue.(*time.Time); isTime {
if _, isTime := fieldValue.(*time.Time); isTime {
// is time
field.IsNormal = true
} else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
// is embedded struct
for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
subField = subField.clone()
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
subField.DBName = prefix + subField.DBName
}
if subField.IsPrimaryKey {
if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok {
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
} else {
subField.IsPrimaryKey = false
}
}
if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil {
if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok {
newJoinTableHandler := &JoinTableHandler{}
newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType)
subField.Relationship.JoinTableHandler = newJoinTableHandler
}
}
modelStruct.StructFields = append(modelStruct.StructFields, subField)
}
continue
} else {
// build relationships
switch indirectType.Kind() {
......@@ -316,23 +239,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
var toFields = toScope.GetStructFields()
relationship.Kind = "has_many"
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
// Toy use OwnerID, OwnerType ('dogs') as foreign key
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
associationType = polymorphic
relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName
// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
relationship.PolymorphicValue = value
} else {
relationship.PolymorphicValue = scope.TableName()
}
polymorphicType.IsForeignKey = true
}
}
// if no foreign keys defined with tag
if len(foreignKeys) == 0 {
// if no association foreign keys defined with tag
......@@ -416,23 +322,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
}
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
// Toy use OwnerID, OwnerType ('cats') as foreign key
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
associationType = polymorphic
relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName
// if Cat has several different types of toys set name for each (instead of default 'cats')
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
relationship.PolymorphicValue = value
} else {
relationship.PolymorphicValue = scope.TableName()
}
polymorphicType.IsForeignKey = true
}
}
// Has One
{
var foreignKeys = tagForeignKeys
......@@ -562,11 +451,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
// Even it is ignored, also possible to decode db value into the field
if value, ok := field.TagSettings["COLUMN"]; ok {
field.DBName = value
} else {
field.DBName = ToDBName(fieldStruct.Name)
}
field.DBName = ToDBName(fieldStruct.Name)
modelStruct.StructFields = append(modelStruct.StructFields, field)
}
......
......@@ -15,7 +15,7 @@ type QueryFn func(query string) string
// retrieve cached items in a []*SomeModel
// for some reason, reflect.New returns a &[]*SomeModel instead,
// I'm guessing slices can't be interfaces, but pointers to slices can?
func (c *Context) pagedByKeys(conn *sqlite.Conn, keyFieldName string, keys []interface{}, sliceType reflect.Type, search *SearchParams) (reflect.Value, error) {
func (c *Context) fetchPagedByPK(conn *sqlite.Conn, PKDBName string, keys []interface{}, sliceType reflect.Type, search *SearchParams) (reflect.Value, error) {
// actually defaults to 999, but let's get some breathing room
result := reflect.New(sliceType)
resultVal := result.Elem()
......@@ -31,7 +31,7 @@ func (c *Context) pagedByKeys(conn *sqlite.Conn, keyFieldName string, keys []int
}
pageAddr := reflect.New(sliceType)
cond := builder.In(keyFieldName, remainingItems[:pageSize]...)
cond := builder.In(PKDBName, remainingItems[:pageSize]...)
err := c.Select(conn, pageAddr.Interface(), cond, search)
if err != nil {
......@@ -45,3 +45,27 @@ func (c *Context) pagedByKeys(conn *sqlite.Conn, keyFieldName string, keys []int
return result, nil
}
func (c *Context) deletePagedByPK(conn *sqlite.Conn, TableName string, PKDBName string, keys []interface{}) error {
remainingItems := keys
for len(remainingItems) > 0 {
var pageSize int
if len(remainingItems) > maxSqlVars {
pageSize = maxSqlVars
}