Commit a8d6dbc5 authored by Amos Wenger's avatar Amos Wenger

Implement basic automigrate, fix some tests

parent 4b064a2d
Pipeline #10415 failed with stage
in 13 seconds
package hades
import (
"fmt"
"reflect"
"strings"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
// TODO: if table already exists, just add fields
func (c *Context) AutoMigrate(conn *sqlite.Conn) error {
for tableName, m := range c.ScopeMap.byDBName {
ms := m.GetModelStruct()
query := fmt.Sprintf("CREATE TABLE %s", tableName)
var columns []string
var pks []string
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
var sqliteType string
switch sf.Struct.Type.Kind() {
case reflect.Int64, reflect.Bool:
sqliteType = "INTEGER"
case reflect.Float64:
sqliteType = "REAL"
case reflect.String:
sqliteType = "TEXT"
default:
return errors.Errorf("Unsupported model field type: %v (in model %v)", sf.Struct.Type, ms.ModelType)
}
column := fmt.Sprintf("%s %s", sf.DBName, sqliteType)
columns = append(columns, column)
if sf.IsPrimaryKey {
pks = append(pks, sf.DBName)
}
}
if len(pks) > 0 {
columns = append(columns, fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(pks, ", ")))
} else {
return errors.Errorf("Model %v has no primary keys", ms.ModelType)
}
query = fmt.Sprintf("%s (%s)", query, strings.Join(columns, ", "))
err := c.ExecRaw(conn, query, nil)
if err != nil {
return err
}
}
return nil
}
......@@ -37,14 +37,17 @@ func Test_BelongsTo(t *testing.T) {
ID: 123,
Desc: "Consumer-grade flamethrowers",
}
t.Log("Saving one fate")
wtest.Must(t, c.SaveOne(conn, someFate))
lea := &Human{
ID: 3,
FateID: someFate.ID,
}
t.Log("Saving one human")
wtest.Must(t, c.SaveOne(conn, lea))
t.Log("Preloading lea")
c.Preload(conn, &hades.PreloadParams{
Record: lea,
Fields: []hades.PreloadField{
......@@ -489,11 +492,11 @@ func withContext(t *testing.T, models []interface{}, f WithContextFunc) {
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
// whoops, automigrate
// wtest.Must(t, conn.AutoMigrate(models...).Error)
c, err := hades.NewContext(makeConsumer(t), models...)
wtest.Must(t, err)
c.Log = true
wtest.Must(t, c.AutoMigrate(conn))
f(conn, c)
}
......@@ -20,13 +20,16 @@ func (scope *Scope) ToEq(rec reflect.Value) builder.Eq {
eq := make(builder.Eq)
for _, sf := range scope.GetModelStruct().StructFields {
if sf.Relationship != nil {
// TODO: set IDs here?
continue
}
eq[sf.DBName] = recEl.FieldByName(sf.Name).Interface()
}
return eq
}
func (c *Context) Insert(conn *sqlite.Conn, rec reflect.Value) error {
scope := c.NewScope(rec)
func (c *Context) Insert(conn *sqlite.Conn, scope *Scope, rec reflect.Value) error {
eq := scope.ToEq(rec)
return c.Exec(conn, builder.Insert(eq).Into(scope.TableName()), nil)
}
......@@ -105,7 +105,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
}
// FIXME: that's slow/bad because of ToEq
err := c.Insert(conn, rec)
err := c.Insert(conn, mtm.Scope, rec)
if err != nil {
return errors.Wrap(err, "creating new relation records")
}
......
......@@ -217,7 +217,7 @@ func (c *Context) saveRows(conn *sqlite.Conn, params *SaveParams, inputIface int
if len(inserts) > 0 {
for _, rec := range inserts {
// FIXME: that's slow/bad because of ToEq
err := c.Insert(conn, rec)
err := c.Insert(conn, scope, rec)
if err != nil {
return errors.Wrap(err, "inserting new DB records")
}
......
......@@ -9,15 +9,13 @@ import (
func (c *Context) Scan(stmt *sqlite.Stmt, columns []string, result reflect.Value) error {
for i, c := range columns {
// FIXME: that's bad/slow
fieldName := FromDBName(c)
field := result.FieldByName(fieldName)
switch field.Type().Kind() {
case reflect.Int64:
case reflect.Int32:
case reflect.Int:
field.SetInt(stmt.ColumnInt64(i))
case reflect.Float32:
case reflect.Float64:
field.SetFloat(stmt.ColumnFloat(i))
case reflect.Bool:
......
package hades
import (
"fmt"
"reflect"
"github.com/pkg/errors"
......@@ -21,16 +20,13 @@ func NewScopeMap() *ScopeMap {
func (sm *ScopeMap) Add(c *Context, m interface{}) error {
val := reflect.ValueOf(m)
fmt.Printf("val = %v, type = %v\n", val, val.Type())
if val.Type().Kind() == reflect.Ptr {
val = val.Elem()
fmt.Printf("was ptr, now val = %v, type = %v\n", val, val.Type())
}
if val.Type().Kind() == reflect.Interface {
val = val.Elem()
fmt.Printf("was interface, now val = %v, type = %v\n", val, val.Type())
}
reflectType := val.Type()
......
......@@ -11,6 +11,9 @@ func (c *Context) Select(conn *sqlite.Conn, result interface{}, cond builder.Con
var columns []string
ms := c.NewScope(result).GetModelStruct()
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
columns = append(columns, sf.DBName)
}
......
......@@ -172,7 +172,7 @@ func (c *Context) WalkType(riMap RecordInfoMap, name string, atyp reflect.Type,
return fmt.Errorf("visitField expects a Slice of Ptr, or a Ptr, but got %v", sf.Struct.Type)
}
if c.ScopeMap.ByType(fieldTyp) != nil {
if c.ScopeMap.ByType(fieldTyp) == nil {
if explicit {
return fmt.Errorf("%s.%s is not an explicitly listed model (%v)", ms.ModelType.Name(), sf.Name, fieldTyp)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment