diff --git a/associations/belongs_to_association.go b/associations/belongs_to_association.go index 2d70667c..a28e5d0a 100644 --- a/associations/belongs_to_association.go +++ b/associations/belongs_to_association.go @@ -31,7 +31,23 @@ func belongsToAssociationBuilder(p associationParams) (Association, error) { ownerVal := p.modelValue.FieldByName(p.field.Name) tags := p.popTags primaryIDField := defaults.String(tags.Find("primary_id").Value, "ID") - ownerIDField := defaults.String(tags.Find("fk_id").Value, fmt.Sprintf("%s%s", p.field.Name, "ID")) + ownerIDField := fmt.Sprintf("%s%s", p.field.Name, "ID") + + if tags.Find("fk_id").Value != "" { + dbTag := tags.Find("fk_id").Value + if _, found := p.modelType.FieldByName(dbTag); !found { + t := p.modelValue.Type() + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Tag.Get("db") == dbTag { + ownerIDField = f.Name + break + } + } + } else { + ownerIDField = dbTag + } + } // belongs_to requires an holding field for the foreign model ID. if _, found := p.modelType.FieldByName(ownerIDField); !found { diff --git a/executors_test.go b/executors_test.go index 5af9304b..8b263c72 100644 --- a/executors_test.go +++ b/executors_test.go @@ -972,7 +972,7 @@ func Test_Eager_Create_Belongs_To(t *testing.T) { car := Taxi{ Model: "Fancy car", - Driver: User{ + Driver: &User{ Name: nulls.NewString("Larry 2"), }, } @@ -1101,7 +1101,7 @@ func Test_Flat_Create_Belongs_To(t *testing.T) { car := Taxi{ Model: "Fancy car", - Driver: user, + Driver: &user, } err = tx.Create(&car) diff --git a/pop_test.go b/pop_test.go index 6dd812af..20e4eda5 100644 --- a/pop_test.go +++ b/pop_test.go @@ -124,7 +124,7 @@ type Taxi struct { ID int `db:"id"` Model string `db:"model"` UserID nulls.Int `db:"user_id"` - Driver User `belongs_to:"user" fk_id:"UserID"` + Driver *User `belongs_to:"user" fk_id:"user_id"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` } diff --git a/preload_associations.go b/preload_associations.go index 99236f2b..5d3be99a 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -270,11 +270,15 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf asocValue := slice.Elem().Index(i) if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() || reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) { - if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array { + + switch { + case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) - continue + case modelAssociationField.Kind() == reflect.Ptr: + modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue)) + default: + modelAssociationField.Set(asocValue) } - modelAssociationField.Set(asocValue) } } }) @@ -380,11 +384,15 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI asocValue := slice.Elem().Index(i) if mmi.mapper.FieldByName(mvalue, fi.Path).Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() || reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, fi.Path), mmi.mapper.FieldByName(asocValue, "ID")) { - if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array { + + switch { + case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) - continue + case modelAssociationField.Kind() == reflect.Ptr: + modelAssociationField.Elem().Set(asocValue) + default: + modelAssociationField.Set(asocValue) } - modelAssociationField.Set(asocValue) } } })