Commit 056d011f authored by Amos Wenger's avatar Amos Wenger

null value handling

parent 055d01da
Pipeline #10442 passed with stage
in 2 minutes and 12 seconds
......@@ -127,15 +127,23 @@ func (c *Context) createTable(conn *sqlite.Conn, ms *ModelStruct) error {
}
var sqliteType string
switch sf.Struct.Type.Kind() {
case reflect.Int64, reflect.Bool:
typ := sf.Struct.Type
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
switch typ.Kind() {
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int,
reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint:
sqliteType = "INTEGER"
case reflect.Float64:
case reflect.Bool:
sqliteType = "BOOLEAN"
case reflect.Float64, reflect.Float32:
sqliteType = "REAL"
case reflect.String:
sqliteType = "TEXT"
case reflect.Struct:
if sf.Struct.Type == reflect.TypeOf(time.Time{}) {
if typ == reflect.TypeOf(time.Time{}) {
sqliteType = "DATETIME"
break
}
......
......@@ -163,7 +163,7 @@ func Test_AutoMigrateAllValidTypes(t *testing.T) {
assert.False(t, pti[1].NotNull)
assert.EqualValues(t, "alive", pti[2].Name)
assert.EqualValues(t, "INTEGER", pti[2].Type)
assert.EqualValues(t, "BOOLEAN", pti[2].Type)
assert.False(t, pti[2].PrimaryKey)
assert.False(t, pti[2].NotNull)
......@@ -176,6 +176,25 @@ func Test_AutoMigrateAllValidTypes(t *testing.T) {
assert.EqualValues(t, "DATETIME", pti[4].Type)
assert.False(t, pti[4].PrimaryKey)
assert.False(t, pti[4].NotNull)
tim := time.Now()
h1 := &Humanoid{
ID: 12,
Alive: true,
BornAt: tim,
FirstName: "Jeremy",
HeartRate: 3.14,
}
ordie(c.SaveOne(conn, h1))
h2 := &Humanoid{}
ordie(c.SelectOne(conn, h2, builder.Eq{"id": 12}))
assert.EqualValues(t, h1.ID, h2.ID)
assert.EqualValues(t, h1.Alive, h2.Alive)
assert.EqualValues(t, h1.BornAt.Format(time.RFC3339Nano), h2.BornAt.Format(time.RFC3339Nano))
assert.EqualValues(t, h1.FirstName, h2.FirstName)
assert.EqualValues(t, h1.HeartRate, h2.HeartRate)
}
func ordie(err error) {
......
package hades
import (
"reflect"
"time"
)
func DBValue(x interface{}) interface{} {
typ := reflect.TypeOf(x)
value := reflect.ValueOf(x)
wasPtr := false
if typ.Kind() == reflect.Ptr {
if value.IsNil() {
return nil
}
wasPtr = true
typ = typ.Elem()
value = value.Elem()
}
switch typ.Kind() {
case reflect.Bool:
if value.Bool() {
return 1
}
return 0
case reflect.Struct:
if typ == reflect.TypeOf(time.Time{}) {
return value.Interface().(time.Time).Format(time.RFC3339Nano)
}
}
if wasPtr {
return value.Interface()
}
return x
}
package hades_test
import (
"testing"
"time"
"github.com/itchio/hades"
"github.com/stretchr/testify/assert"
)
func Test_DBValue(t *testing.T) {
var s *string = nil
assert.Nil(t, hades.DBValue(s))
tim := time.Now()
assert.EqualValues(t, tim.Format(time.RFC3339Nano), hades.DBValue(tim))
assert.EqualValues(t, 42, hades.DBValue(42))
assert.EqualValues(t, 3.14, hades.DBValue(3.14))
}
package hades
import (
"fmt"
"reflect"
"time"
......@@ -47,9 +46,9 @@ func DiffRecord(x, y interface{}, scope *Scope) (ChangedFields, error) {
v1f := v1.Field(i)
v2f := v2.Field(i)
iseq, err := eq(v1f, v2f)
iseq, err := iseq(sf, v1f, v2f)
if err != nil {
return nil, errors.Wrap(err, "while comparing fields")
return res, err
}
if !iseq {
......@@ -63,121 +62,56 @@ func DiffRecord(x, y interface{}, scope *Scope) (ChangedFields, error) {
return res, nil
}
// Comparison.
// Taken from text/template
func iseq(sf *StructField, v1f reflect.Value, v2f reflect.Value) (bool, error) {
typ := sf.Struct.Type
originalTyp := typ
var (
errBadComparison = errors.New("incompatible types for comparison")
errNoComparison = errors.New("missing argument for comparison")
)
if typ.Kind() == reflect.Ptr {
if v1f.IsNil() {
if !v2f.IsNil() {
return false, nil // only v1 nil
}
return true, nil // both nil
} else {
if v2f.IsNil() {
return false, nil // only v2 nil
}
type kind int
const (
invalidKind kind = iota
boolKind
complexKind
intKind
floatKind
stringKind
uintKind
timeKind
)
// neither are nil, let's compare values
typ = typ.Elem()
v1f = v1f.Elem()
v2f = v2f.Elem()
}
}
func basicKind(v reflect.Value) (kind, error) {
switch v.Kind() {
switch typ.Kind() {
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int,
reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint:
eq := v1f.Int() == v2f.Int()
return eq, nil
case reflect.Bool:
return boolKind, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intKind, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintKind, nil
case reflect.Float32, reflect.Float64:
return floatKind, nil
case reflect.Complex64, reflect.Complex128:
return complexKind, nil
eq := v1f.Bool() == v2f.Bool()
return eq, nil
case reflect.Float64, reflect.Float32:
eq := v1f.Float() == v2f.Float()
return eq, nil
case reflect.String:
return stringKind, nil
eq := v1f.String() == v2f.String()
return eq, nil
case reflect.Struct:
if _, ok := v.Interface().(time.Time); ok {
return timeKind, nil
}
}
return invalidKind, fmt.Errorf("bad type for comparison: %v", v.Type())
}
// eq evaluates the comparison a == b || a == c || ...
func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
v1 := indirectInterface(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
if len(arg2) == 0 {
return false, errNoComparison
}
for _, arg := range arg2 {
v2 := indirectInterface(arg)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
truth := false
if k1 != k2 {
// Special case: Can compare integer values regardless of type's sign.
switch {
case k1 == intKind && k2 == uintKind:
truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
case k1 == uintKind && k2 == intKind:
truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
default:
return false, errBadComparison
}
} else {
switch k1 {
case boolKind:
truth = v1.Bool() == v2.Bool()
case complexKind:
truth = v1.Complex() == v2.Complex()
case floatKind:
truth = v1.Float() == v2.Float()
case intKind:
truth = v1.Int() == v2.Int()
case stringKind:
truth = v1.String() == v2.String()
case uintKind:
truth = v1.Uint() == v2.Uint()
case timeKind:
truth = v1.Interface().(time.Time) == v2.Interface().(time.Time)
default:
panic("invalid kind")
}
}
if truth {
return true, nil
if typ == reflect.TypeOf(time.Time{}) {
eq := v1f.Interface().(time.Time).UnixNano() == v2f.Interface().(time.Time).UnixNano()
return eq, nil
}
}
return false, nil
}
// indirectInterface returns the concrete value in an interface value,
// or else the zero reflect.Value.
// That is, if v represents the interface value x, the result is the same as reflect.ValueOf(x):
// the fact that x was an interface value is forgotten.
func indirectInterface(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Interface {
return v
}
if v.IsNil() {
return reflect.Value{}
}
return v.Elem()
return false, errors.Errorf("Don't know how to compare fields of type %v", originalTyp)
}
func (cf ChangedFields) ToEq() builder.Eq {
eq := make(builder.Eq)
for sf, v := range cf {
eq[sf.DBName] = v
eq[sf.DBName] = DBValue(v)
}
return eq
}
......@@ -23,7 +23,7 @@ func (scope *Scope) ToEq(rec reflect.Value) builder.Eq {
if !sf.IsNormal {
continue
}
eq[sf.DBName] = recEl.FieldByName(sf.Name).Interface()
eq[sf.DBName] = DBValue(recEl.FieldByName(sf.Name).Interface())
}
return eq
}
......
package hades_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"crawshaw.io/sqlite"
"github.com/go-xorm/builder"
"github.com/itchio/hades"
"github.com/itchio/wharf/state"
)
func Test_Null(t *testing.T) {
consumer := &state.Consumer{
OnMessage: func(lvl string, message string) {
t.Logf("[%s] %s", lvl, message)
},
}
type Download struct {
ID int64
FinishedAt *time.Time
ErrorCode *int64
ErrorMessage *string
}
c, err := hades.NewContext(consumer, &Download{})
if err != nil {
panic(err)
}
c.Log = true
sqlite.Logger = func(code sqlite.ErrorCode, msg []byte) {
t.Logf("[SQLITE] %d %s", code, string(msg))
}
dbpool, err := sqlite.Open("file:memory:?mode=memory", 0, 10)
if err != nil {
panic(err)
}
conn := dbpool.Get(context.Background().Done())
defer dbpool.Put(conn)
ordie(c.AutoMigrate(conn))
{
d := &Download{
ID: 123,
}
ordie(c.SaveOne(conn, d))
{
dd := &Download{}
ordie(c.SelectOne(conn, dd, builder.Eq{"id": 123}))
assert.EqualValues(t, 123, dd.ID)
assert.Nil(t, dd.FinishedAt)
assert.Nil(t, dd.ErrorCode)
assert.Nil(t, dd.ErrorMessage)
}
errMsg := "No rest for the wicked"
d.ErrorMessage = &errMsg
errCode := int64(9000)
d.ErrorCode = &errCode
finishedAt := time.Now()
d.FinishedAt = &finishedAt
ordie(c.SaveOne(conn, d))
{
dd := &Download{}
ordie(c.SelectOne(conn, dd, builder.Eq{"id": 123}))
assert.EqualValues(t, 123, dd.ID)
assert.EqualValues(t, *d.ErrorMessage, *dd.ErrorMessage)
assert.EqualValues(t, *d.ErrorCode, *dd.ErrorCode)
assert.EqualValues(t, (*d.FinishedAt).Format(time.RFC3339Nano), (*dd.FinishedAt).Format(time.RFC3339Nano))
}
d.ErrorMessage = nil
ordie(c.SaveOne(conn, d))
{
dd := &Download{}
ordie(c.SelectOne(conn, dd, builder.Eq{"id": 123}))
assert.EqualValues(t, 123, dd.ID)
assert.Nil(t, dd.ErrorMessage)
assert.EqualValues(t, *d.ErrorCode, *dd.ErrorCode)
assert.EqualValues(t, (*d.FinishedAt).Format(time.RFC3339Nano), (*dd.FinishedAt).Format(time.RFC3339Nano))
}
}
}
......@@ -2,24 +2,78 @@ package hades
import (
"reflect"
"time"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
func (c *Context) Scan(stmt *sqlite.Stmt, fields []*StructField, result reflect.Value) error {
for i, sf := range fields {
func (c *Context) Scan(stmt *sqlite.Stmt, structFields []*StructField, result reflect.Value) error {
for i, sf := range structFields {
field := result.FieldByName(sf.Name)
switch field.Type().Kind() {
case reflect.Int64:
field.SetInt(stmt.ColumnInt64(i))
case reflect.Float64:
field.SetFloat(stmt.ColumnFloat(i))
fieldEl := field
typ := field.Type()
wasPtr := false
colTyp := stmt.ColumnType(i)
if typ.Kind() == reflect.Ptr {
wasPtr = true
if colTyp == sqlite.SQLITE_NULL {
field.Set(reflect.Zero(field.Type()))
continue
}
fieldEl = field.Elem()
typ = typ.Elem()
}
switch typ.Kind() {
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int,
reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint:
val := stmt.ColumnInt64(i)
if wasPtr {
field.Set(reflect.ValueOf(&val))
} else {
fieldEl.SetInt(val)
}
case reflect.Float64, reflect.Float32:
val := stmt.ColumnFloat(i)
if wasPtr {
field.Set(reflect.ValueOf(&val))
} else {
fieldEl.SetFloat(val)
}
case reflect.Bool:
field.SetBool(stmt.ColumnInt(i) == 1)
val := stmt.ColumnInt(i) == 1
if wasPtr {
field.Set(reflect.ValueOf(&val))
} else {
fieldEl.SetBool(val)
}
case reflect.String:
field.SetString(stmt.ColumnText(i))
val := stmt.ColumnText(i)
if wasPtr {
field.Set(reflect.ValueOf(&val))
} else {
fieldEl.SetString(val)
}
case reflect.Struct:
if typ == reflect.TypeOf(time.Time{}) {
text := stmt.ColumnText(i)
tim, err := time.Parse(time.RFC3339Nano, text)
if err == nil {
if wasPtr {
field.Set(reflect.ValueOf(&tim))
} else {
fieldEl.Set(reflect.ValueOf(tim))
}
}
break
}
fallthrough
default:
return errors.Errorf("For model %s, unknown kind %s for field %s", result.Type(), field.Type().Kind(), sf.Name)
}
......
......@@ -12,11 +12,6 @@ import (
"github.com/stretchr/testify/assert"
)
type Honor struct {
ID int64
Title string
}
func Test_Select(t *testing.T) {
consumer := &state.Consumer{
OnMessage: func(lvl string, message string) {
......@@ -24,6 +19,11 @@ func Test_Select(t *testing.T) {
},
}
type Honor struct {
ID int64
Title string
}
c, err := hades.NewContext(consumer, &Honor{})
if err != nil {
panic(err)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment