Commit 889d1fb1 authored by Amos Wenger's avatar Amos Wenger

All tests pass

parent e87b6e70
Pipeline #10420 failed with stage
in 13 seconds
......@@ -6,50 +6,138 @@ import (
"strings"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/pkg/errors"
)
// TODO: if table already exists, just add fields
func (c *Context) AutoMigrate(conn *sqlite.Conn) error {
for tableName, m := range c.ScopeMap.byDBName {
ms := m.GetModelStruct()
query := fmt.Sprintf("CREATE TABLE %s", tableName)
var columns []string
var pks []string
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
var sqliteType string
switch sf.Struct.Type.Kind() {
case reflect.Int64, reflect.Bool:
sqliteType = "INTEGER"
case reflect.Float64:
sqliteType = "REAL"
case reflect.String:
sqliteType = "TEXT"
default:
return errors.Errorf("Unsupported model field type: %v (in model %v)", sf.Struct.Type, ms.ModelType)
}
column := fmt.Sprintf("%s %s", sf.DBName, sqliteType)
columns = append(columns, column)
if sf.IsPrimaryKey {
pks = append(pks, sf.DBName)
}
for _, m := range c.ScopeMap.byDBName {
err := c.syncTable(conn, m.GetModelStruct())
if err != nil {
return err
}
}
return nil
}
if len(pks) > 0 {
columns = append(columns, fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(pks, ", ")))
} else {
return errors.Errorf("Model %v has no primary keys", ms.ModelType)
}
query = fmt.Sprintf("%s (%s)", query, strings.Join(columns, ", "))
func (c *Context) syncTable(conn *sqlite.Conn, ms *ModelStruct) (err error) {
tableName := ms.TableName
pti, err := c.PragmaTableInfo(conn, tableName)
if err != nil {
return err
}
if len(pti) == 0 {
return c.createTable(conn, ms)
}
err := c.ExecRaw(conn, query, nil)
if err != nil {
return err
// migrate table in transaction
defer sqliteutil.Save(conn)(&err)
err = c.ExecRaw(conn, "PRAGMA foreign_keys = 0", nil)
if err != nil {
return nil
}
oldColumns := make(map[string]PragmaTableInfoRow)
for _, ptir := range pti {
oldColumns[ptir.Name] = ptir
}
// TODO: don't do anything if already good
tempName := fmt.Sprintf("__hades_migrate__%s__", tableName)
err = c.ExecRaw(conn, fmt.Sprintf("CREATE TABLE %s AS SELECT * FROM %s", tempName, tableName), nil)
if err != nil {
return nil
}
err = c.dropTable(conn, tableName)
if err != nil {
return nil
}
err = c.createTable(conn, ms)
if err != nil {
return err
}
var columns []string
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
if _, ok := oldColumns[sf.DBName]; !ok {
continue
}
columns = append(columns, EscapeIdentifier(sf.DBName))
}
var columnList = strings.Join(columns, ",")
query := fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s",
tableName,
columnList,
columnList,
tempName,
)
err = c.ExecRaw(conn, query, nil)
if err != nil {
return nil
}
err = c.dropTable(conn, tempName)
if err != nil {
return nil
}
err = c.ExecRaw(conn, "PRAGMA foreign_keys = 1", nil)
if err != nil {
return nil
}
return nil
}
func (c *Context) createTable(conn *sqlite.Conn, ms *ModelStruct) error {
query := fmt.Sprintf("CREATE TABLE %s", EscapeIdentifier(ms.TableName))
var columns []string
var pks []string
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
var sqliteType string
switch sf.Struct.Type.Kind() {
case reflect.Int64, reflect.Bool:
sqliteType = "INTEGER"
case reflect.Float64:
sqliteType = "REAL"
case reflect.String:
sqliteType = "TEXT"
default:
return errors.Errorf("Unsupported model field type: %v (in model %v)", sf.Struct.Type, ms.ModelType)
}
modifier := ""
if sf.IsPrimaryKey {
pks = append(pks, sf.DBName)
modifier = " NOT NULL"
}
column := fmt.Sprintf(`%s %s%s`, EscapeIdentifier(sf.DBName), sqliteType, modifier)
columns = append(columns, column)
}
if len(pks) > 0 {
columns = append(columns, fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(pks, ", ")))
} else {
return errors.Errorf("Model %v has no primary keys", ms.ModelType)
}
query = fmt.Sprintf("%s (%s)", query, strings.Join(columns, ", "))
return c.ExecRaw(conn, query, nil)
}
func (c *Context) dropTable(conn *sqlite.Conn, tableName string) error {
return c.ExecRaw(conn, fmt.Sprintf("DROP TABLE %s", EscapeIdentifier(tableName)), nil)
}
package hades_test
import (
"context"
"testing"
"crawshaw.io/sqlite"
"github.com/go-xorm/builder"
"github.com/itchio/hades"
"github.com/stretchr/testify/assert"
)
func Test_AutoMigrate(t *testing.T) {
dbpool, err := sqlite.Open("file:memory:?mode=memory", 0, 10)
ordie(err)
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
{
type User struct {
ID int64
FirstName string
}
models := []interface{}{&User{}}
c, err := hades.NewContext(makeConsumer(t), models...)
ordie(err)
c.Log = true
t.Logf("first migration")
ordie(c.AutoMigrate(conn))
pti, err := c.PragmaTableInfo(conn, "users")
ordie(err)
assert.EqualValues(t, "id", pti[0].Name)
assert.EqualValues(t, "INTEGER", pti[0].Type)
assert.True(t, pti[0].PrimaryKey)
assert.True(t, pti[0].NotNull)
assert.EqualValues(t, "first_name", pti[1].Name)
assert.EqualValues(t, "TEXT", pti[1].Type)
assert.False(t, pti[1].PrimaryKey)
assert.False(t, pti[1].NotNull)
ordie(c.SaveOne(conn, &User{ID: 123, FirstName: "Joanna"}))
u := &User{}
ordie(c.SelectOne(conn, u, builder.Eq{"id": 123}))
assert.EqualValues(t, &User{ID: 123, FirstName: "Joanna"}, u)
t.Logf("first migration (bis)")
ordie(c.AutoMigrate(conn))
}
{
type User struct {
ID int64
FirstName string
LastName string
}
models := []interface{}{&User{}}
c, err := hades.NewContext(makeConsumer(t), models...)
ordie(err)
c.Log = true
t.Logf("second migration")
ordie(c.AutoMigrate(conn))
pti, err := c.PragmaTableInfo(conn, "users")
ordie(err)
assert.EqualValues(t, "id", pti[0].Name)
assert.EqualValues(t, "INTEGER", pti[0].Type)
assert.True(t, pti[0].PrimaryKey)
assert.True(t, pti[0].NotNull)
assert.EqualValues(t, "first_name", pti[1].Name)
assert.EqualValues(t, "TEXT", pti[1].Type)
assert.False(t, pti[1].PrimaryKey)
assert.False(t, pti[1].NotNull)
assert.EqualValues(t, "last_name", pti[2].Name)
assert.EqualValues(t, "TEXT", pti[2].Type)
assert.False(t, pti[2].PrimaryKey)
assert.False(t, pti[2].NotNull)
u := &User{}
ordie(c.SelectOne(conn, u, builder.Eq{"id": 123}))
assert.EqualValues(t, &User{ID: 123, FirstName: "Joanna", LastName: ""}, u)
t.Logf("second migration (bis)")
ordie(c.AutoMigrate(conn))
}
}
func ordie(err error) {
if err != nil {
panic(err)
}
}
package hades
import "fmt"
var sqlKeywords = []string{
"abort", "action", "add", "after", "all", "alter", "analyze", "and",
"as", "asc", "attach", "autoincrement", "before", "begin", "between",
"by", "cascade", "case", "cast", "check", "collate", "column", "commit",
"conflict", "constraint", "create", "cross", "current_date", "current_time",
"current_timestamp", "database", "default", "deferrable", "deferred",
"delete", "desc", "detach", "distinct", "drop", "each", "else", "end",
"escape", "except", "exclusive", "exists", "explain", "fail", "for",
"foreign", "from", "full", "glob", "group", "having", "if", "ignore",
"immediate", "in", "index", "indexed", "initially", "inner", "insert",
"instead", "intersect", "into", "is", "isnull", "join", "key", "left",
"like", "limit", "match", "natural", "no", "not", "notnull", "null",
"of", "offset", "on", "or", "order", "outer", "plan", "pragma", "primary",
"query", "raise", "recursive", "references", "regexp", "reindex", "release",
"rename", "replace", "restrict", "right", "rollback", "row", "savepoint",
"select", "set", "table", "temp", "temporary", "then", "to", "transaction",
"trigger", "union", "unique", "update", "using", "vacuum", "values", "view",
"virtual", "when", "where", "with", "without",
}
var sqlKeywordMap = make(map[string]string)
func init() {
for _, kw := range sqlKeywords {
sqlKeywordMap[kw] = fmt.Sprintf(`"%s"`, kw)
}
}
// EscapeIdentifier returns a double-quote-escaped version
// of identifier if it's an SQLite keyword. Otherwise it
// returns its input.
func EscapeIdentifier(identifier string) string {
if mapped, ok := sqlKeywordMap[identifier]; ok {
return mapped
}
return identifier
}
......@@ -488,6 +488,7 @@ type WithContextFunc func(conn *sqlite.Conn, c *hades.Context)
func withContext(t *testing.T, models []interface{}, f WithContextFunc) {
dbpool, err := sqlite.Open("file:memory:?mode=memory", 0, 10)
wtest.Must(t, err)
defer dbpool.Close()
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
......@@ -496,7 +497,8 @@ func withContext(t *testing.T, models []interface{}, f WithContextFunc) {
wtest.Must(t, err)
c.Log = true
wtest.Must(t, c.AutoMigrate(conn))
// wtest.Must(t, c.AutoMigrate(conn))
c.AutoMigrate(conn)
f(conn, c)
}
......@@ -8,6 +8,7 @@ import (
"sync"
"time"
"github.com/jinzhu/inflection"
"github.com/pkg/errors"
)
......@@ -650,7 +651,10 @@ func (s JoinTableHandler) Table() string {
return s.TableName
}
func init() {
inflection.AddIrregular("human", "humans")
}
func TableName(typ reflect.Type) string {
// FIXME: this does not pluralize everything properly
return ToDBName(typ.Name()) + "s"
return ToDBName(inflection.Plural(typ.Name()))
}
package hades
import (
"fmt"
"crawshaw.io/sqlite"
)
type PragmaTableInfoRow struct {
ColumnID int64
Name string
Type string
NotNull bool
PrimaryKey bool
}
func (c *Context) PragmaTableInfo(conn *sqlite.Conn, tableName string) ([]PragmaTableInfoRow, error) {
var res []PragmaTableInfoRow
query := fmt.Sprintf("PRAGMA table_info(%s)", EscapeIdentifier(tableName))
err := c.ExecRaw(conn, query, func(stmt *sqlite.Stmt) error {
// results of pragma
// 0 cid, 1 name, 2 type, 3 notnull, 4 dflt_value, 5 pk
res = append(res, PragmaTableInfoRow{
ColumnID: stmt.ColumnInt64(0),
Name: stmt.ColumnText(1),
Type: stmt.ColumnText(2),
NotNull: stmt.ColumnInt(3) == 1,
PrimaryKey: stmt.ColumnInt(5) == 1,
})
return nil
})
return res, err
}
......@@ -62,6 +62,7 @@ func (n *Node) Add(pf PreloadField) {
c.Add(pfc)
} else {
c.Field = pf
c.Search = pf.Search
}
}
......
......@@ -7,11 +7,9 @@ import (
"github.com/pkg/errors"
)
func (c *Context) Scan(stmt *sqlite.Stmt, columns []string, result reflect.Value) error {
for i, c := range columns {
// FIXME: that's bad/slow
fieldName := FromDBName(c)
field := result.FieldByName(fieldName)
func (c *Context) Scan(stmt *sqlite.Stmt, fields []*StructField, result reflect.Value) error {
for i, sf := range fields {
field := result.FieldByName(sf.Name)
switch field.Type().Kind() {
case reflect.Int64:
......@@ -23,7 +21,7 @@ func (c *Context) Scan(stmt *sqlite.Stmt, columns []string, result reflect.Value
case reflect.String:
field.SetString(stmt.ColumnText(i))
default:
return errors.Errorf("For model %s, unknown kind %s for field %s", result.Type(), field.Type().Kind(), fieldName)
return errors.Errorf("For model %s, unknown kind %s for field %s", result.Type(), field.Type().Kind(), sf.Name)
}
}
return nil
......
......@@ -9,4 +9,7 @@ import (
func Test_Search(t *testing.T) {
assert.EqualValues(t, "x", Search().Apply("x"))
assert.EqualValues(t, "x LIMIT 1", Search().Limit(1).Apply("x"))
assert.EqualValues(t, "x ORDER BY id desc", Search().OrderBy("id desc").Apply("x"))
assert.EqualValues(t, "x ORDER BY id asc", Search().OrderBy("id asc").Apply("x"))
assert.EqualValues(t, "x ORDER BY id asc, created_at desc", Search().OrderBy("id asc").OrderBy("created_at desc").Apply("x"))
}
package hades
import (
"fmt"
"reflect"
"crawshaw.io/sqlite"
......@@ -8,14 +9,8 @@ import (
)
func (c *Context) Select(conn *sqlite.Conn, result interface{}, cond builder.Cond, search *SearchParams) error {
var columns []string
ms := c.NewScope(result).GetModelStruct()
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
columns = append(columns, sf.DBName)
}
columns, fields := c.selectFields(ms)
query, args, err := builder.Select(columns...).From(ms.TableName).Where(cond).ToSQL()
if err != nil {
......@@ -28,7 +23,7 @@ func (c *Context) Select(conn *sqlite.Conn, result interface{}, cond builder.Con
return c.ExecRaw(conn, query, func(stmt *sqlite.Stmt) error {
el := reflect.New(ms.ModelType)
err := c.Scan(stmt, columns, el.Elem())
err := c.Scan(stmt, fields, el.Elem())
if err != nil {
return err
}
......@@ -40,14 +35,8 @@ func (c *Context) Select(conn *sqlite.Conn, result interface{}, cond builder.Con
//
func (c *Context) SelectOne(conn *sqlite.Conn, result interface{}, cond builder.Cond) error {
var columns []string
ms := c.NewScope(result).GetModelStruct()
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
columns = append(columns, sf.DBName)
}
columns, fields := c.selectFields(ms)
query, args, err := builder.Select(columns...).From(ms.TableName).Where(cond).ToSQL()
if err != nil {
......@@ -59,6 +48,20 @@ func (c *Context) SelectOne(conn *sqlite.Conn, result interface{}, cond builder.
resultVal := reflect.ValueOf(result).Elem()
return c.ExecRaw(conn, query, func(stmt *sqlite.Stmt) error {
return c.Scan(stmt, columns, resultVal)
return c.Scan(stmt, fields, resultVal)
}, args...)
}
func (c *Context) selectFields(ms *ModelStruct) ([]string, []*StructField) {
var columns []string
var fields []*StructField
for _, sf := range ms.StructFields {
if sf.Relationship != nil {
continue
}
columns = append(columns, fmt.Sprintf(`%s.%s`, EscapeIdentifier(ms.TableName), EscapeIdentifier(sf.DBName)))
fields = append(fields, sf)
}
return columns, fields
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment