Commit 82eb2f62 authored by Amos Wenger's avatar Amos Wenger

Panic if invalid struct tags are passed, many2many => many_to_many

parent cdadf38b
Pipeline #10985 passed with stage
in 32 seconds
......@@ -19,13 +19,13 @@ func Test_BelongsTo(t *testing.T) {
type Human struct {
ID int64
FateID int64
Fate *Fate `hades:"ignore"`
Fate *Fate
}
type Joke struct {
ID string
HumanID int64
Human *Human `hades:"ignore"`
Human *Human
}
models := []interface{}{&Human{}, &Fate{}, &Joke{}}
......
......@@ -2,7 +2,7 @@ package hades
import "reflect"
// JoinTableHandlerInterface is an interface for how to handle many2many relations
// JoinTableHandlerInterface is an interface for how to handle many_to_many relations
type JoinTableHandlerInterface interface {
// initialize join table handler
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
......
......@@ -14,18 +14,18 @@ import (
type Language struct {
ID int64
Words []*Word `hades:"many2many:language_words"`
Words []*Word `hades:"many_to_many:language_words"`
}
type Word struct {
ID string
Comment string
Languages []*Language `hades:"many2many:language_words"`
Languages []*Language `hades:"many_to_many:language_words"`
}
type LanguageWord struct {
LanguageID int64 `hades:"primary_key;auto_increment:false"`
WordID string `hades:"primary_key;auto_increment:false"`
LanguageID int64 `hades:"primary_key"`
WordID string `hades:"primary_key"`
}
func Test_ManyToMany(t *testing.T) {
......@@ -135,10 +135,10 @@ type Game struct {
}
type ProfileGame struct {
ProfileID int64 `hades:"primary_key;auto_increment:false"`
ProfileID int64 `hades:"primary_key"`
Profile *Profile
GameID int64 `hades:"primary_key;auto_increment:false"`
GameID int64 `hades:"primary_key"`
Game *Game
Order int64
......@@ -185,13 +185,13 @@ func Test_ManyToManyRevenge(t *testing.T) {
type Piece struct {
ID int64
Authors []*Author `hades:"many2many:piece_authors"`
Authors []*Author `hades:"many_to_many:piece_authors"`
}
type Author struct {
ID int64
Name string
Pieces []*Piece `hades:"many2many:piece_authors"`
Pieces []*Piece `hades:"many_to_many:piece_authors"`
}
type PieceAuthor struct {
......
package hades
import (
"fmt"
"go/ast"
"reflect"
"strings"
......@@ -55,7 +56,7 @@ type StructField struct {
IsSquashed bool
SquashedFields []*StructField
Tag reflect.StructTag
TagSettings map[string]string
TagSettings map[TagSetting]string
Struct reflect.StructField
IsForeignKey bool
Relationship *Relationship
......@@ -84,6 +85,30 @@ func getForeignField(column string, fields []*StructField) *StructField {
return nil
}
type TagSetting string
const (
TagSettingIgnore TagSetting = "-"
TagSettingSquash TagSetting = "squash"
TagSettingManyToMany TagSetting = "many_to_many"
TagSettingPrimaryKey TagSetting = "primary_key"
TagSettingForeignKey TagSetting = "foreign_key"
TagSettingAssociationForeignKey TagSetting = "association_foreign_key"
TagSettingJoinTableForeignKey TagSetting = "join_table_foreign_key"
TagSettingAssociationJoinTableForeignKey TagSetting = "association_join_table_foreign_key"
)
var ValidTagSettings = map[TagSetting]bool{
TagSettingIgnore: true,
TagSettingSquash: true,
TagSettingManyToMany: true,
TagSettingPrimaryKey: true,
TagSettingForeignKey: true,
TagSettingAssociationForeignKey: true,
TagSettingJoinTableForeignKey: true,
TagSettingAssociationJoinTableForeignKey: true,
}
// GetModelStruct get value's model struct, relationships based on struct and tag definition
func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct
......@@ -119,20 +144,20 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
Name: fieldStruct.Name,
Names: []string{fieldStruct.Name},
Tag: fieldStruct.Tag,
TagSettings: parseTagSetting(fieldStruct.Tag),
TagSettings: parseTagSetting(reflectType, fieldStruct.Name, fieldStruct.Tag),
}
// is ignored field
if _, ok := field.TagSettings["-"]; ok {
if _, ok := field.TagSettings[TagSettingIgnore]; ok {
field.IsIgnored = true
} else if _, ok := field.TagSettings["SQUASH"]; ok {
} else if _, ok := field.TagSettings[TagSettingSquash]; ok {
field.IsSquashed = true
nestedModelStruct := scope.ctx.NewScope(reflect.Zero(field.Struct.Type).Interface()).GetModelStruct()
for _, sf := range nestedModelStruct.StructFields {
field.SquashedFields = append(field.SquashedFields, sf)
}
} else {
if _, ok := field.TagSettings["PRIMARY_KEY"]; ok {
if _, ok := field.TagSettings[TagSettingPrimaryKey]; ok {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
......@@ -159,13 +184,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
elemType = field.Struct.Type
)
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
if foreignKey := field.TagSettings[TagSettingForeignKey]; foreignKey != "" {
foreignKeys = strings.Split(foreignKey, ",")
}
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
if foreignKey := field.TagSettings[TagSettingAssociationForeignKey]; foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
}
......@@ -174,13 +197,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
if elemType.Kind() == reflect.Struct {
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
if manyToMany := field.TagSettings[TagSettingManyToMany]; manyToMany != "" {
relationship.Kind = "many_to_many"
{ // Foreign Keys for Source
joinTableDBNames := []string{}
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
if foreignKey := field.TagSettings[TagSettingAssociationJoinTableForeignKey]; foreignKey != "" {
joinTableDBNames = strings.Split(foreignKey, ",")
}
......@@ -211,7 +234,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
{ // Foreign Keys for Association (Destination)
associationJoinTableDBNames := []string{}
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
if foreignKey := field.TagSettings[TagSettingAssociationJoinTableForeignKey]; foreignKey != "" {
associationJoinTableDBNames = strings.Split(foreignKey, ",")
}
......@@ -240,7 +263,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
joinTableHandler.Setup(relationship, manyToMany, reflectType, elemType)
relationship.JoinTableHandler = &joinTableHandler
field.Relationship = relationship
} else {
......@@ -322,13 +345,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
tagAssociationForeignKeys []string
)
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
if foreignKey := field.TagSettings[TagSettingForeignKey]; foreignKey != "" {
tagForeignKeys = strings.Split(foreignKey, ",")
}
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
if foreignKey := field.TagSettings[TagSettingAssociationForeignKey]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
}
......@@ -488,17 +509,31 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields
}
func parseTagSetting(tags reflect.StructTag) map[string]string {
setting := map[string]string{}
for _, str := range []string{tags.Get("hades")} {
func parseTagSetting(reflectType reflect.Type, fieldName string, tags reflect.StructTag) map[TagSetting]string {
setting := map[TagSetting]string{}
if str, ok := tags.Lookup("hades"); ok {
tags := strings.Split(str, ";")
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
k := strings.TrimSpace(v[0])
if _, ok := ValidTagSettings[TagSetting(k)]; !ok {
var validTags []string
for vt := range ValidTagSettings {
validTags = append(validTags, string(vt))
}
panic(fmt.Sprintf("invalid tag setting %q for field %s of type %v - valid tag settings are %s",
k,
fieldName,
reflectType,
strings.Join(validTags, ", "),
))
}
if len(v) >= 2 {
setting[k] = strings.Join(v[1:], ":")
setting[TagSetting(k)] = strings.Join(v[1:], ":")
} else {
setting[k] = k
setting[TagSetting(k)] = k
}
}
}
......
......@@ -100,7 +100,7 @@ func (c *Context) saveJoins(params *SaveParams, conn *sqlite.Conn, mtm *ManyToMa
} else {
// if not passed an explicit record, make it ourselves
// that typically means the join table doesn't have additional
// columns and is a simple many2many
// columns and is a simple many_to_many
eq := builder.Eq{
mtm.SourceDBName: sourceKey,
mtm.DestinDBName: joinRec.DestinKey,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment