diff --git a/bind.go b/bind.go index 6d4d18ad9f..760d3cab83 100644 --- a/bind.go +++ b/bind.go @@ -41,6 +41,7 @@ type fieldTextDecoder struct { fieldName string tag string // query,param,header,respHeader ... reqField string + et reflect.Type dec bind.TextDecoder get func(c Ctx, key string, defaultValue ...string) string } @@ -52,17 +53,26 @@ func (d *fieldTextDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { } var err error - if len(d.parentIndex) > 0 { - for _, i := range d.parentIndex { - reqValue = reqValue.Field(i) + for _, i := range d.parentIndex { + reqValue = reqValue.Field(i) + } + + // Pointer support for struct elems + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + elem := reflect.New(d.et) + err = d.dec.UnmarshalString(text, elem.Elem()) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.reqField, err) } - err = d.dec.UnmarshalString(text, reqValue.Field(d.index)) + field.Set(elem) - } else { - err = d.dec.UnmarshalString(text, reqValue.Field(d.index)) + return nil } + // Non-pointer elems + err = d.dec.UnmarshalString(text, field) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.reqField, err) } diff --git a/bind_test.go b/bind_test.go index ae135ec83f..b94ba4d5af 100644 --- a/bind_test.go +++ b/bind_test.go @@ -25,7 +25,7 @@ func Test_Binder(t *testing.T) { ctx.Request().Header.Set("content-type", "application/json") var req struct { - ID string `param:"id"` + ID *string `param:"id"` } var body struct { @@ -34,7 +34,7 @@ func Test_Binder(t *testing.T) { err := ctx.Bind().Req(&req).JSON(&body).Err() require.NoError(t, err) - require.Equal(t, "id string", req.ID) + require.Equal(t, "id string", *req.ID) require.Equal(t, "john doe", body.Name) } @@ -47,11 +47,12 @@ func Test_Binder_Nested(t *testing.T) { c.Request().Header.SetContentType("") c.Request().URI().SetQueryString("name=tom&nested.and.age=10&nested.and.test=john") + // TODO: pointer support for structs var req struct { Name string `query:"name"` Nested struct { And struct { - Age int `query:"age"` + Age *int `query:"age"` Test string `query:"test"` } `query:"and"` } `query:"nested"` @@ -61,7 +62,7 @@ func Test_Binder_Nested(t *testing.T) { require.NoError(t, err) require.Equal(t, "tom", req.Name) require.Equal(t, "john", req.Nested.And.Test) - require.Equal(t, 10, req.Nested.And.Age) + require.Equal(t, 10, *req.Nested.And.Age) } func Test_Binder_Nested_Slice(t *testing.T) { @@ -91,6 +92,33 @@ func Test_Binder_Nested_Slice(t *testing.T) { require.Equal(t, "tom", req.Name) } +/*func Test_Binder_Nested_Deeper_Slice(t *testing.T) { + t.Parallel() + app := New() + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + c.Request().SetBody([]byte(``)) + c.Request().Header.SetContentType("") + c.Request().URI().SetQueryString("data[0][users][0][name]=john&data[0][users][0][age]=10&data[1][users][0][name]=doe&data[1][users][0][age]=12") + + var req struct { + Data []struct { + Users []struct { + Name string `query:"name"` + Age int `query:"age"` + } `query:"subData"` + } `query:"data"` + } + + err := c.Bind().Req(&req).Err() + require.NoError(t, err) + require.Equal(t, 2, len(req.Data)) + require.Equal(t, "john", req.Data[0].Users[0].Name) + require.Equal(t, 10, req.Data[0].Users[0].Age) + require.Equal(t, "doe", req.Data[1].Users[0].Name) + require.Equal(t, 12, req.Data[1].Users[0].Age) +}*/ + // go test -run Test_Bind_BasicType -v func Test_Bind_BasicType(t *testing.T) { t.Parallel() diff --git a/binder_compile.go b/binder_compile.go index 4e0331a014..3b091389e9 100644 --- a/binder_compile.go +++ b/binder_compile.go @@ -113,6 +113,9 @@ func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOp } // Nested binding support + if field.Type.Kind() == reflect.Ptr { + field.Type = field.Type.Elem() + } if field.Type.Kind() == reflect.Struct { var decoders []decoder el := field.Type @@ -185,7 +188,12 @@ func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tag return nil, errors.New("unexpected tag scope " + strconv.Quote(tagScope)) } - textDecoder, err := bind.CompileTextDecoder(field.Type) + et := field.Type + if field.Type.Kind() == reflect.Ptr { + et = field.Type.Elem() + } + + textDecoder, err := bind.CompileTextDecoder(et) if err != nil { return nil, err } @@ -197,6 +205,7 @@ func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tag reqField: tagContent, dec: textDecoder, get: get, + et: et, } if len(parentIndex) > 0 { @@ -206,11 +215,13 @@ func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tag return []decoder{fieldDecoder}, nil } +// TODO type subElem struct { et reflect.Type tag string index int elementDecoder bind.TextDecoder + //subElems []subElem } func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tagScope string, tagContent string) ([]decoder, error) {