Commit d45f9ef2 authored by Amos Wenger's avatar Amos Wenger

Import a bunch of code from gorm

parent 41442d0f
Pipeline #10406 failed with stage
in 20 seconds
package hades
import (
"reflect"
"github.com/itchio/wharf/state"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
type ScopeMap map[reflect.Type]*gorm.Scope
func (sm ScopeMap) ByDBName(dbName string) *gorm.Scope {
for _, s := range sm {
if s.TableName() == dbName {
return s
}
}
return nil
}
type Context struct {
ScopeMap *ScopeMap
Consumer *state.Consumer
ScopeMap ScopeMap
Stats Stats
Error error
}
type Stats struct {
......@@ -32,35 +18,32 @@ type Stats struct {
Current int64
}
func NewContext(db *gorm.DB, models []interface{}, consumer *state.Consumer) *Context {
scopeMap := make(ScopeMap)
for _, m := range models {
mtyp := reflect.TypeOf(m)
scopeMap[mtyp] = db.NewScope(m)
}
func NewContext(consumer *state.Consumer, models ...interface{}) (*Context, error) {
if consumer == nil {
consumer = &state.Consumer{}
}
return &Context{
c := &Context{
Consumer: consumer,
ScopeMap: scopeMap,
ScopeMap: NewScopeMap(),
}
}
type InTransactionFunc func(c *Context, tx *gorm.DB) error
for _, m := range models {
err := c.ScopeMap.Add(c, m)
if err != nil {
return nil, err
}
}
func (c *Context) InTransaction(db *gorm.DB, itf InTransactionFunc) error {
tx := db.Begin()
return c, nil
}
err := itf(c, tx)
if err != nil {
tx.Rollback()
return errors.Wrap(err, "in db transaction")
} else {
tx.Commit()
func (c *Context) NewScope(value interface{}) *Scope {
return &Scope{
Value: value,
ctx: c,
}
}
return nil
func (c *Context) AddError(err error) {
c.Error = err
}
package hades
import "github.com/pkg/errors"
var (
ErrUnaddressable = errors.New("using unaddressable value")
)
package hades
import (
"fmt"
"reflect"
"github.com/pkg/errors"
)
// Field model field definition
type Field struct {
*StructField
IsBlank bool
Field reflect.Value
}
// Set set a value to the field
func (field *Field) Set(value interface{}) (err error) {
if !field.Field.IsValid() {
return errors.New("field value not valid")
}
if !field.Field.CanAddr() {
return ErrUnaddressable
}
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
fieldValue := field.Field
if reflectValue.IsValid() {
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else {
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
}
fieldValue = fieldValue.Elem()
}
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else {
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
}
}
} else {
field.Field.Set(reflect.Zero(field.Field.Type()))
}
field.IsBlank = isBlank(field.Field)
return err
}
package hades
import "reflect"
// JoinTableHandlerInterface is an interface for how to handle many2many relations
type JoinTableHandlerInterface interface {
// initialize join table handler
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
// Table return join table's table name
Table() string
// SourceForeignKeys return source foreign keys
SourceForeignKeys() []JoinTableForeignKey
// DestinationForeignKeys return destination foreign keys
DestinationForeignKeys() []JoinTableForeignKey
}
// JoinTableForeignKey join table foreign key struct
type JoinTableForeignKey struct {
DBName string
AssociationDBName string
}
// JoinTableSource is a struct that contains model type and foreign keys
type JoinTableSource struct {
ModelType reflect.Type
ForeignKeys []JoinTableForeignKey
}
This diff is collapsed.
......@@ -4,7 +4,8 @@ import (
"fmt"
"reflect"
"github.com/jinzhu/gorm"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
"github.com/pkg/errors"
)
......@@ -22,13 +23,13 @@ type SaveParams struct {
PartialJoins []string
}
func (c *Context) Save(db *gorm.DB, params *SaveParams) error {
return c.InTransaction(db, func(c *Context, tx *gorm.DB) error {
return c.SaveNoTransaction(tx, params)
})
func (c *Context) Save(conn *sqlite.Conn, params *SaveParams) (err error) {
defer sqliteutil.Save(conn)(&err)
return c.SaveNoTransaction(conn, params)
}
func (c *Context) SaveNoTransaction(tx *gorm.DB, params *SaveParams) error {
func (c *Context) SaveNoTransaction(conn *sqlite.Conn, params *SaveParams) error {
if params == nil {
return errors.New("Save: params cannot be nil")
}
......@@ -53,9 +54,6 @@ func (c *Context) SaveNoTransaction(tx *gorm.DB, params *SaveParams) error {
entities := make(AllEntities)
addEntity := func(v reflect.Value) error {
typ := v.Type()
if _, ok := c.ScopeMap[typ]; !ok {
return fmt.Errorf("not a model type: %s", typ)
}
entities[typ] = append(entities[typ], v.Interface())
return nil
}
......@@ -177,7 +175,7 @@ func (c *Context) SaveNoTransaction(tx *gorm.DB, params *SaveParams) error {
}
for _, m := range entities {
err := c.saveRows(tx, params, m)
err := c.saveRows(conn, params, m)
if err != nil {
return errors.Wrap(err, "saving rows")
}
......@@ -185,7 +183,7 @@ func (c *Context) SaveNoTransaction(tx *gorm.DB, params *SaveParams) error {
for _, ri := range riMap {
if ri.ManyToMany != nil {
err := c.saveJoins(params, tx, ri.ManyToMany)
err := c.saveJoins(params, conn, ri.ManyToMany)
if err != nil {
return errors.Wrap(err, "saving joins")
}
......
package hades
import (
"fmt"
"reflect"
"github.com/jinzhu/gorm"
"crawshaw.io/sqlite"
"github.com/pkg/errors"
)
func (c *Context) saveJoins(params *SaveParams, tx *gorm.DB, mtm *ManyToMany) error {
partial := false
for _, pj := range params.PartialJoins {
if mtm.JoinTable == gorm.ToDBName(pj) {
partial = true
}
}
joinType := reflect.PtrTo(mtm.Scope.GetModelStruct().ModelType)
getDestinKey := func(v reflect.Value) interface{} {
return v.Elem().FieldByName(mtm.DestinName).Interface()
}
for sourceKey, joinRecs := range mtm.Values {
cacheAddr := reflect.New(reflect.SliceOf(joinType))
err := tx.Where(
fmt.Sprintf(`"%s" = ?`, mtm.SourceDBName),
sourceKey,
).Find(cacheAddr.Interface()).Error
if err != nil {
return errors.Wrap(err, "fetching cached records to compare later")
}
cache := cacheAddr.Elem()
cacheByDestinKey := make(map[interface{}]reflect.Value)
for i := 0; i < cache.Len(); i++ {
rec := cache.Index(i)
cacheByDestinKey[getDestinKey(rec)] = rec
}
freshByDestinKey := make(map[interface{}]reflect.Value)
for _, joinRec := range joinRecs {
freshByDestinKey[joinRec.DestinKey] = joinRec.Record
}
var deletes []interface{}
updates := make(map[interface{}]ChangedFields)
var inserts []JoinRec
// compare with cache: will result in delete or update
for i := 0; i < cache.Len(); i++ {
crec := cache.Index(i)
destinKey := getDestinKey(crec)
if frec, ok := freshByDestinKey[destinKey]; ok {
if frec.IsValid() {
// compare to maybe update
ifrec := frec.Elem().Interface()
icrec := crec.Elem().Interface()
cf, err := DiffRecord(ifrec, icrec, mtm.Scope)
if err != nil {
return errors.Wrap(err, "diffing database records")
}
if cf != nil {
updates[destinKey] = cf
}
}
} else {
deletes = append(deletes, destinKey)
}
}
for _, joinRec := range joinRecs {
if _, ok := cacheByDestinKey[joinRec.DestinKey]; !ok {
inserts = append(inserts, joinRec)
}
}
if partial {
// Not deleting extra join records, as requested
} else {
if len(deletes) > 0 {
// FIXME: this needs to be paginated to avoid hitting SQLite max variables
rec := reflect.New(joinType.Elem())
err := tx.
Delete(
rec.Interface(),
fmt.Sprintf(
`"%s" = ? and "%s" in (?)`,
mtm.SourceDBName,
mtm.DestinDBName,
),
sourceKey,
deletes,
).Error
if err != nil {
return errors.Wrap(err, "deleting extraneous relations")
}
}
}
for _, joinRec := range inserts {
rec := joinRec.Record
if !rec.IsValid() {
// if not passed an explicit record, make it ourselves
// that typically means the join table doesn't have additional
// columns and is a simple many2many
rec = reflect.New(joinType.Elem())
rec.Elem().FieldByName(mtm.SourceName).Set(reflect.ValueOf(sourceKey))
rec.Elem().FieldByName(mtm.DestinName).Set(reflect.ValueOf(joinRec.DestinKey))
}
err := tx.Create(rec.Interface()).Error
if err != nil {
return errors.Wrap(err, "creating new relation records")
}
}
for destinKey, rec := range updates {
err := tx.Table(mtm.Scope.TableName()).
Where(
fmt.Sprintf(
`"%s" = ? and "%s" = ?`,
mtm.SourceDBName,
mtm.DestinDBName,
),
sourceKey,
destinKey,
).Updates(rec).Error
if err != nil {
return errors.Wrap(err, "updating related records")
}
}
}
return nil
func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMany) error {
return errors.Errorf("stub")
// partial := false
// for _, pj := range params.PartialJoins {
// if mtm.JoinTable == gorm.ToDBName(pj) {
// partial = true
// }
// }
//
// joinType := reflect.PtrTo(mtm.Scope.GetModelStruct().ModelType)
//
// getDestinKey := func(v reflect.Value) interface{} {
// return v.Elem().FieldByName(mtm.DestinName).Interface()
// }
//
// for sourceKey, joinRecs := range mtm.Values {
// cacheAddr := reflect.New(reflect.SliceOf(joinType))
//
// err := conn.Where(
// fmt.Sprintf(`"%s" = ?`, mtm.SourceDBName),
// sourceKey,
// ).Find(cacheAddr.Interface()).Error
//
// if err != nil {
// return errors.Wrap(err, "fetching cached records to compare later")
// }
//
// cache := cacheAddr.Elem()
//
// cacheByDestinKey := make(map[interface{}]reflect.Value)
// for i := 0; i < cache.Len(); i++ {
// rec := cache.Index(i)
// cacheByDestinKey[getDestinKey(rec)] = rec
// }
//
// freshByDestinKey := make(map[interface{}]reflect.Value)
// for _, joinRec := range joinRecs {
// freshByDestinKey[joinRec.DestinKey] = joinRec.Record
// }
//
// var deletes []interface{}
// updates := make(map[interface{}]ChangedFields)
// var inserts []JoinRec
//
// // compare with cache: will result in delete or update
// for i := 0; i < cache.Len(); i++ {
// crec := cache.Index(i)
// destinKey := getDestinKey(crec)
// if frec, ok := freshByDestinKey[destinKey]; ok {
// if frec.IsValid() {
// // compare to maybe update
// ifrec := frec.Elem().Interface()
// icrec := crec.Elem().Interface()
//
// cf, err := DiffRecord(ifrec, icrec, mtm.Scope)
// if err != nil {
// return errors.Wrap(err, "diffing database records")
// }
//
// if cf != nil {
// updates[destinKey] = cf
// }
// }
// } else {
// deletes = append(deletes, destinKey)
// }
// }
//
// for _, joinRec := range joinRecs {
// if _, ok := cacheByDestinKey[joinRec.DestinKey]; !ok {
// inserts = append(inserts, joinRec)
// }
// }
//
// if partial {
// // Not deleting extra join records, as requested
// } else {
// if len(deletes) > 0 {
// // FIXME: this needs to be paginated to avoid hitting SQLite max variables
// rec := reflect.New(joinType.Elem())
// err := conn.
// Delete(
// rec.Interface(),
// fmt.Sprintf(
// `"%s" = ? and "%s" in (?)`,
// mtm.SourceDBName,
// mtm.DestinDBName,
// ),
// sourceKey,
// deletes,
// ).Error
// if err != nil {
// return errors.Wrap(err, "deleting extraneous relations")
// }
// }
// }
//
// for _, joinRec := range inserts {
// rec := joinRec.Record
//
// if !rec.IsValid() {
// // if not passed an explicit record, make it ourselves
// // that typically means the join table doesn't have additional
// // columns and is a simple many2many
// rec = reflect.New(joinType.Elem())
// rec.Elem().FieldByName(mtm.SourceName).Set(reflect.ValueOf(sourceKey))
// rec.Elem().FieldByName(mtm.DestinName).Set(reflect.ValueOf(joinRec.DestinKey))
// }
//
// err := conn.Create(rec.Interface()).Error
// if err != nil {
// return errors.Wrap(err, "creating new relation records")
// }
// }
//
// for destinKey, rec := range updates {
// err := conn.Table(mtm.Scope.TableName()).
// Where(
// fmt.Sprintf(
// `"%s" = ? and "%s" = ?`,
// mtm.SourceDBName,
// mtm.DestinDBName,
// ),
// sourceKey,
// destinKey,
// ).Updates(rec).Error
// if err != nil {
// return errors.Wrap(err, "updating related records")
// }
// }
// }
//
// return nil
}
This diff is collapsed.
package hades
import "reflect"
// Scope contain current operation's information when you perform any operation on the database
type Scope struct {
Value interface{}
ctx *Context
primaryKeyField *Field
skipLeft bool
fields *[]*Field
selectAttrs *[]string
}
// IndirectValue return scope's reflect value's indirect value
func (scope *Scope) IndirectValue() reflect.Value {
return indirect(reflect.ValueOf(scope.Value))
}
// New create a new Scope
func (scope *Scope) New(value interface{}) *Scope {
return &Scope{Value: value}
}
// Fields get value's fields
func (scope *Scope) Fields() []*Field {
if scope.fields == nil {
var (
fields []*Field
indirectScopeValue = scope.IndirectValue()
isStruct = indirectScopeValue.Kind() == reflect.Struct
)
for _, structField := range scope.GetModelStruct().StructFields {
if isStruct {
fieldValue := indirectScopeValue
for _, name := range structField.Names {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
}
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
} else {
fields = append(fields, &Field{StructField: structField, IsBlank: true})
}
}
scope.fields = &fields
}
return *scope.fields
}
// FieldByName find `gorm.Field` with field name or db name
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
var (
dbName = ToDBName(name)
mostMatchedField *Field
)
for _, field := range scope.Fields() {
if field.Name == name || field.DBName == name {
return field, true
}
if field.DBName == dbName {
mostMatchedField = field
}
}
return mostMatchedField, mostMatchedField != nil
}
// PrimaryFields return scope's primary fields
func (scope *Scope) PrimaryFields() (fields []*Field) {
for _, field := range scope.Fields() {
if field.IsPrimaryKey {
fields = append(fields, field)
}
}
return fields
}
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
func (scope *Scope) PrimaryField() *Field {
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
if len(primaryFields) > 1 {
if field, ok := scope.FieldByName("id"); ok {
return field
}
}
return scope.PrimaryFields()[0]
}
return nil
}
// PrimaryKey get main primary field's db name
func (scope *Scope) PrimaryKey() string {
if field := scope.PrimaryField(); field != nil {
return field.DBName
}
return ""
}
// PrimaryKeyZero check main primary field's value is blank or not
func (scope *Scope) PrimaryKeyZero() bool {
field := scope.PrimaryField()
return field == nil || field.IsBlank
}
// PrimaryKeyValue get the primary key's value
func (scope *Scope) PrimaryKeyValue() interface{} {
if field := scope.PrimaryField(); field != nil && field.Field.IsValid() {
return field.Field.Interface()
}
return 0
}
// HasColumn to check if has column
func (scope *Scope) HasColumn(column string) bool {
for _, field := range scope.GetStructFields() {
if field.IsNormal && (field.Name == column || field.DBName == column) {
return true
}
}
return false
}
// TableName return table name
func (scope *Scope) TableName() string {
return scope.GetModelStruct().defaultTableName
}
// Err add error to Scope
func (scope *Scope) Err(err error) error {
if err != nil {
scope.ctx.AddError(err)
}
return err
}
package hades
import (
"reflect"
"github.com/pkg/errors"
)
type ScopeMap struct {
byType map[reflect.Type]*Scope
byDBName map[string]*Scope
}
func NewScopeMap() *ScopeMap {
return &ScopeMap{
byType: make(map[reflect.Type]*Scope),
byDBName: make(map[string]*Scope),
}
}
func (sm *ScopeMap) Add(c *Context, m interface{}) error {
reflectType := reflect.ValueOf(m).Type()
for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
// what should we do if it's not a struct?
if reflectType.Kind() != reflect.Struct {
return errors.Errorf("hades expects all models to be structs, but got %v instead", reflectType)
}
s := c.NewScope(m)
sm.byType[reflectType] = s
sm.byDBName[s.TableName()] = s
return nil
}
func (sm *ScopeMap) ByDBName(dbname string) *Scope {
return sm.byDBName[dbname]
}
func (sm *ScopeMap) ByType(typ reflect.Type) *Scope {
return sm.byType[typ]
}
package hades
import (
"fmt"
"log"
"reflect"
"strings"
"crawshaw.io/sqlite"
"crawshaw.io/sqlite/sqliteutil"
)
func (c *Context) Select(conn *sqlite.Conn, result interface{}, where string, args ...interface{}) error {
s := c.NewScope(result)
var columns []string
ms := s.GetModelStruct()
for _, sf := range ms.StructFields {
columns = append(columns, sf.DBName)
}
query := fmt.Sprintf("select %s from %s where %s",
strings.Join(columns, ", "),
s.TableName(),
where)
log.Printf("query = %s", query)
args = append([]interface{}{}, args...)
resultVal := reflect.ValueOf(result).Elem()