Skip to content

Commit

Permalink
Add support for contextual validation
Browse files Browse the repository at this point in the history
By passing the Gin context to bindings, custom validators can take
advantage of the information in the context.
  • Loading branch information
kszafran committed Sep 22, 2021
1 parent e73cffe commit c790e2a
Show file tree
Hide file tree
Showing 15 changed files with 284 additions and 71 deletions.
45 changes: 40 additions & 5 deletions binding/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

package binding

import "net/http"
import (
"context"
"net/http"
)

// Content-Type MIME of the most common data formats.
const (
Expand All @@ -32,20 +35,41 @@ type Binding interface {
Bind(*http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar with Bind,
// CtxBinding enables contextual validation by adding CtxBind to Binding.
// Custom validators can take advantage of the information in the context.
type CtxBinding interface {
Binding
CtxBind(context.Context, *http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar to Bind,
// but it reads the body from supplied bytes instead of req.Body.
type BindingBody interface {
Binding
BindBody([]byte, interface{}) error
}

// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
// but it read the Params.
// CtxBindingBody enables contextual validation by adding CtxBindBody to BindingBody.
// Custom validators can take advantage of the information in the context.
type CtxBindingBody interface {
BindingBody
CtxBind(context.Context, *http.Request, interface{}) error
CtxBindBody(context.Context, []byte, interface{}) error
}

// BindingUri is similar to Bind, but it read the Params.
type BindingUri interface {
Name() string
BindUri(map[string][]string, interface{}) error
}

// CtxBindingUri enables contextual validation by adding CtxBindUri to BindingUri.
// Custom validators can take advantage of the information in the context.
type CtxBindingUri interface {
BindingUri
CtxBindUri(context.Context, map[string][]string, interface{}) error
}

// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
Expand All @@ -64,6 +88,14 @@ type StructValidator interface {
Engine() interface{}
}

// CtxStructValidator is an extension of StructValidator that requires implementing
// context-aware validation.
// Custom validators can take advantage of the information in the context.
type CtxStructValidator interface {
StructValidator
ValidateStructCtx(context.Context, interface{}) error
}

// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
// under the hood.
Expand Down Expand Up @@ -110,9 +142,12 @@ func Default(method, contentType string) Binding {
}
}

func validate(obj interface{}) error {
func validateCtx(ctx context.Context, obj interface{}) error {
if Validator == nil {
return nil
}
if v, ok := Validator.(CtxStructValidator); ok {
return v.ValidateStructCtx(ctx, obj)
}
return Validator.ValidateStruct(obj)
}
15 changes: 13 additions & 2 deletions binding/binding_msgpack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package binding

import (
"bytes"
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -35,7 +36,7 @@ func TestBindingMsgPack(t *testing.T) {
string(data), string(data[1:]))
}

func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
func testMsgPackBodyBinding(t *testing.T, b CtxBinding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())

obj := FooStruct{}
Expand All @@ -48,7 +49,17 @@ func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body,
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEMSGPACK)
err = MsgPack.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)

obj2 := ConditionalFooStruct{}
req = requestWithBody("POST", path, body)
req.Header.Add("Content-Type", MIMEMSGPACK)
err = b.CtxBind(context.Background(), req, &obj2)
assert.NoError(t, err)
assert.Equal(t, "bar", obj2.Foo)

err = b.CtxBind(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint
assert.Error(t, err)
}

Expand Down
47 changes: 41 additions & 6 deletions binding/binding_nomsgpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

package binding

import "net/http"
import (
"context"
"net/http"
)

// Content-Type MIME of the most common data formats.
const (
Expand All @@ -30,20 +33,41 @@ type Binding interface {
Bind(*http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar with Bind,
// CtxBinding enables contextual validation by adding CtxBind to Binding.
// Custom validators can take advantage of the information in the context.
type CtxBinding interface {
Binding
CtxBind(context.Context, *http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar to Bind,
// but it reads the body from supplied bytes instead of req.Body.
type BindingBody interface {
Binding
BindBody([]byte, interface{}) error
}

// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
// but it read the Params.
// CtxBindingBody enables contextual validation by adding CtxBindBody to BindingBody.
// Custom validators can take advantage of the information in the context.
type CtxBindingBody interface {
BindingBody
CtxBind(context.Context, *http.Request, interface{}) error
CtxBindBody(context.Context, []byte, interface{}) error
}

// BindingUri is similar to Bind, but it read the Params.
type BindingUri interface {
Name() string
BindUri(map[string][]string, interface{}) error
}

// CtxBindingUri enables contextual validation by adding CtxBindUri to BindingUri.
// Custom validators can take advantage of the information in the context.
type CtxBindingUri interface {
BindingUri
CtxBindUri(context.Context, map[string][]string, interface{}) error
}

// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
Expand All @@ -62,6 +86,14 @@ type StructValidator interface {
Engine() interface{}
}

// CtxStructValidator is an extension of StructValidator that requires implementing
// context-aware validation.
// Custom validators can take advantage of the information in the context.
type CtxStructValidator interface {
StructValidator
ValidateStructCtx(context.Context, interface{}) error
}

// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
// under the hood.
Expand All @@ -85,7 +117,7 @@ var (
// Default returns the appropriate Binding instance based on the HTTP method
// and the content type.
func Default(method, contentType string) Binding {
if method == "GET" {
if method == http.MethodGet {
return Form
}

Expand All @@ -105,9 +137,12 @@ func Default(method, contentType string) Binding {
}
}

func validate(obj interface{}) error {
func validateCtx(ctx context.Context, obj interface{}) error {
if Validator == nil {
return nil
}
if v, ok := Validator.(CtxStructValidator); ok {
return v.ValidateStructCtx(ctx, obj)
}
return Validator.ValidateStruct(obj)
}
75 changes: 66 additions & 9 deletions binding/binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package binding

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
Expand All @@ -20,6 +21,7 @@ import (
"time"

"github.com/gin-gonic/gin/testdata/protoexample"
"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
Expand All @@ -38,6 +40,10 @@ type FooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required,max=32"`
}

type ConditionalFooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required_if_condition,max=32"`
}

type FooBarStruct struct {
FooStruct
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"`
Expand Down Expand Up @@ -144,6 +150,16 @@ type FooStructForMapPtrType struct {
PtrBar *map[string]interface{} `form:"ptr_bar"`
}

func init() {
_ = Validator.Engine().(*validator.Validate).RegisterValidationCtx(
"required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool {
if ctx.Value("condition") == true {
return !fl.Field().IsZero()
}
return true
})
}

func TestBindingDefault(t *testing.T) {
assert.Equal(t, Form, Default("GET", ""))
assert.Equal(t, Form, Default("GET", MIMEJSON))
Expand Down Expand Up @@ -796,6 +812,38 @@ func TestUriBinding(t *testing.T) {
assert.Equal(t, map[string]interface{}(nil), not.Name)
}

func TestUriBindingWithCtx(t *testing.T) {
b := Uri

type Tag struct {
Name string `uri:"name" binding:"required_if_condition"`
}

empty := make(map[string][]string)
assert.NoError(t, b.CtxBindUri(context.Background(), empty, new(Tag)))
assert.Error(t, b.CtxBindUri(context.WithValue(context.Background(), "condition", true), empty, new(Tag))) // nolint
}

func TestUriBindingWithNonCtxValidator(t *testing.T) {
prev := Validator
defer func() {
Validator = prev
}()
Validator = &nonCtxValidator{}

TestUriBinding(t)
}

type nonCtxValidator defaultValidator

func (v *nonCtxValidator) ValidateStruct(obj interface{}) error {
return (*defaultValidator)(v).ValidateStruct(obj)
}

func (v *nonCtxValidator) Engine() interface{} {
return (*defaultValidator)(v).Engine()
}

func TestUriInnerBinding(t *testing.T) {
type Tag struct {
Name string `uri:"name"`
Expand Down Expand Up @@ -1179,7 +1227,7 @@ func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody
assert.Error(t, err)
}

func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
func testBodyBinding(t *testing.T, b CtxBinding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())

obj := FooStruct{}
Expand All @@ -1190,7 +1238,16 @@ func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody

obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)

obj2 := ConditionalFooStruct{}
req = requestWithBody("POST", path, body)
err = b.CtxBind(context.Background(), req, &obj2)
assert.NoError(t, err)
assert.Equal(t, "bar", obj2.Foo)

err = b.CtxBind(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint
assert.Error(t, err)
}

Expand All @@ -1204,7 +1261,7 @@ func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, ba

var obj2 []FooStruct
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj2)
err = b.Bind(req, &obj2)
assert.Error(t, err)
}

Expand Down Expand Up @@ -1249,7 +1306,7 @@ func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body

obj = FooStructUseNumber{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1267,7 +1324,7 @@ func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, bod

obj = FooStructUseNumber{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1285,7 +1342,7 @@ func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath

obj = FooStructDisallowUnknownFields{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
assert.Contains(t, err.Error(), "what")
}
Expand All @@ -1301,7 +1358,7 @@ func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, bad

obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1318,7 +1375,7 @@ func testProtoBodyBinding(t *testing.T, b Binding, name, path, badPath, body, ba
obj = protoexample.Test{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEPROTOBUF)
err = ProtoBuf.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand Down Expand Up @@ -1349,7 +1406,7 @@ func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body
obj = protoexample.Test{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEPROTOBUF)
err = ProtoBuf.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand Down
Loading

0 comments on commit c790e2a

Please sign in to comment.