Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support non-strict mode to decode object to map when unregistered #309

Merged
merged 7 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,23 @@ type Decoder struct {
refs []interface{}
// record type refs, both list and map need it
typeRefs *TypeRefs
classInfoList []*classInfo
classInfoList []*ClassInfo
isSkip bool

// In strict mode, a class data can be decoded only when the class is registered, otherwise error returned.
// In non-strict mode, a class data will be decoded to a map when the class is not registered.
// The default is non-strict mode, user can change it as required.
Strict bool
}

// FindClassInfo find ClassInfo for the given name in decoder class info list.
func (d *Decoder) FindClassInfo(javaName string) *ClassInfo {
for _, info := range d.classInfoList {
if info.javaName == javaName {
return info
}
}
return nil
}

// Error part
Expand All @@ -49,6 +64,16 @@ func NewDecoder(b []byte) *Decoder {
return &Decoder{reader: bufio.NewReader(bytes.NewReader(b)), typeRefs: &TypeRefs{records: map[string]bool{}}}
}

// NewStrictDecoder generate a strict mode decoder instance.
// In strict mode, all target class must be registered.
wongoo marked this conversation as resolved.
Show resolved Hide resolved
func NewStrictDecoder(b []byte) *Decoder {
return &Decoder{
reader: bufio.NewReader(bytes.NewReader(b)),
typeRefs: &TypeRefs{records: map[string]bool{}},
Strict: true,
}
}

// NewDecoderSize generate a decoder instance.
func NewDecoderSize(b []byte, size int) *Decoder {
return &Decoder{reader: bufio.NewReaderSize(bytes.NewReader(b), size), typeRefs: &TypeRefs{records: map[string]bool{}}}
Expand Down
12 changes: 11 additions & 1 deletion encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,21 @@ import (

// Encoder struct
type Encoder struct {
classInfoList []*classInfo
classInfoList []*ClassInfo
buffer []byte
refMap map[unsafe.Pointer]_refElem
}

// classIndex find the index of the given java name in encoder class info list.
func (e *Encoder) classIndex(javaName string) int {
for i := range e.classInfoList {
if javaName == e.classInfoList[i].javaName {
return i
}
}
return -1
}

// NewEncoder generate an encoder instance
func NewEncoder() *Encoder {
buffer := make([]byte, 64)
Expand Down
18 changes: 9 additions & 9 deletions hessian.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func NewHessianCodec(reader *bufio.Reader) *HessianCodec {
}
}

// NewHessianCodec generate a new hessian codec instance
// NewHessianCodecCustom generate a new hessian codec instance.
func NewHessianCodecCustom(pkgType PackageType, reader *bufio.Reader, bodyLen int) *HessianCodec {
return &HessianCodec{
pkgType: pkgType,
Expand Down Expand Up @@ -129,15 +129,15 @@ func (h *HessianCodec) ReadHeader(header *DubboHeader) error {
return perrors.Errorf("serialization ID:%v", header.SerialID)
}

flag := buf[2] & FLAG_EVENT
if flag != Zero {
headerFlag := buf[2] & FLAG_EVENT
if headerFlag != Zero {
header.Type |= PackageHeartbeat
}
flag = buf[2] & FLAG_REQUEST
if flag != Zero {
headerFlag = buf[2] & FLAG_REQUEST
if headerFlag != Zero {
header.Type |= PackageRequest
flag = buf[2] & FLAG_TWOWAY
if flag != Zero {
headerFlag = buf[2] & FLAG_TWOWAY
if headerFlag != Zero {
header.Type |= PackageRequest_TwoWay
}
} else {
Expand Down Expand Up @@ -197,7 +197,7 @@ func (h *HessianCodec) ReadBody(rspObj interface{}) error {
case PackageRequest | PackageHeartbeat, PackageResponse | PackageHeartbeat:
case PackageRequest:
if rspObj != nil {
if err = unpackRequestBody(NewDecoder(buf[:]), rspObj); err != nil {
if err = unpackRequestBody(NewStrictDecoder(buf[:]), rspObj); err != nil {
return perrors.WithStack(err)
}
}
Expand All @@ -212,7 +212,7 @@ func (h *HessianCodec) ReadBody(rspObj interface{}) error {
return nil
}

// ignore body, but only read attachments
// ReadAttachments ignore body, but only read attachments
func (h *HessianCodec) ReadAttachments() (map[string]string, error) {
if h.reader.Buffered() < h.bodyLen {
return nil, ErrBodyNotEnough
Expand Down
52 changes: 33 additions & 19 deletions hessian_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,39 +114,52 @@ func doTestResponse(t *testing.T, packageType PackageType, responseStatus byte,
func TestResponse(t *testing.T) {
caseObj := Case{A: "a", B: 1}
decodedResponse := &Response{}
RegisterPOJO(&caseObj)

arr := []*Case{&caseObj}
var arrRes []interface{}
decodedResponse.RspObj = &arrRes
decodedResponse.RspObj = nil
doTestResponse(t, PackageResponse, Response_OK, arr, decodedResponse, func() {
arrRes, ok := decodedResponse.RspObj.([]*Case)
if !ok {
t.Errorf("expect []*Case, but get %s", reflect.TypeOf(decodedResponse.RspObj).String())
return
}
assert.Equal(t, 1, len(arrRes))
assert.Equal(t, &caseObj, arrRes[0])
})

decodedResponse.RspObj = &Case{}
doTestResponse(t, PackageResponse, Response_OK, &Case{A: "a", B: 1}, decodedResponse, nil)
doTestResponse(t, PackageResponse, Response_OK, &caseObj, decodedResponse, func() {
assert.Equal(t, &caseObj, decodedResponse.RspObj)
})

s := "ok!!!!!"
strObj := ""
decodedResponse.RspObj = &strObj
doTestResponse(t, PackageResponse, Response_OK, s, decodedResponse, nil)
doTestResponse(t, PackageResponse, Response_OK, s, decodedResponse, func() {
assert.Equal(t, s, decodedResponse.RspObj)
})

var intObj int64
decodedResponse.RspObj = &intObj
doTestResponse(t, PackageResponse, Response_OK, int64(3), decodedResponse, nil)
doTestResponse(t, PackageResponse, Response_OK, int64(3), decodedResponse, func() {
assert.Equal(t, int64(3), decodedResponse.RspObj)
})

boolObj := false
decodedResponse.RspObj = &boolObj
doTestResponse(t, PackageResponse, Response_OK, true, decodedResponse, nil)
doTestResponse(t, PackageResponse, Response_OK, true, decodedResponse, func() {
assert.Equal(t, true, decodedResponse.RspObj)
})

strObj = ""
decodedResponse.RspObj = &strObj
doTestResponse(t, PackageResponse, Response_SERVER_ERROR, "error!!!!!", decodedResponse, nil)
errorMsg := "error!!!!!"
decodedResponse.RspObj = nil
doTestResponse(t, PackageResponse, Response_SERVER_ERROR, errorMsg, decodedResponse, func() {
assert.Equal(t, "java exception:error!!!!!", decodedResponse.Exception.Error())
})

decodedResponse.RspObj = nil
decodedResponse.Exception = nil
mapObj := map[string][]*Case{"key": {&caseObj}}
mapRes := map[interface{}]interface{}{}
decodedResponse.RspObj = &mapRes
doTestResponse(t, PackageResponse, Response_OK, mapObj, decodedResponse, func() {
mapRes, ok := decodedResponse.RspObj.(map[interface{}]interface{})
if !ok {
t.Errorf("expect map[string][]*Case, but get %s", reflect.TypeOf(decodedResponse.RspObj).String())
return
}
c, ok := mapRes["key"]
if !ok {
assert.FailNow(t, "no key in decoded response map")
Expand Down Expand Up @@ -211,7 +224,8 @@ func TestHessianCodec_ReadAttachments(t *testing.T) {
t.Log(h)

err = codecR1.ReadBody(body)
assert.Equal(t, "can not find go type name com.test.caseb in registry", err.Error())
assert.NoError(t, err)
// assert.Equal(t, "can not find go type name com.test.caseb in registry", err.Error())
attrs, err := codecR2.ReadAttachments()
assert.NoError(t, err)
assert.Equal(t, "2.6.4", attrs[DUBBO_VERSION_KEY])
Expand Down
2 changes: 1 addition & 1 deletion int.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (d *Decoder) decInt32(flag int32) (int32, error) {
}
}

func (d *Encoder) encTypeInt32(b []byte, p interface{}) ([]byte, error) {
func (e *Encoder) encTypeInt32(b []byte, p interface{}) ([]byte, error) {
value := reflect.ValueOf(p)
if PackPtr(value).IsNil() {
return EncNull(b), nil
Expand Down
2 changes: 1 addition & 1 deletion java_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (JavaCollectionSerializer) EncObject(e *Encoder, vv POJO) error {
return nil
}

func (JavaCollectionSerializer) DecObject(d *Decoder, typ reflect.Type, cls *classInfo) (interface{}, error) {
func (JavaCollectionSerializer) DecObject(d *Decoder, typ reflect.Type, cls *ClassInfo) (interface{}, error) {
// for the java impl of hessian encode collections as list, which will not be decoded as object in go impl, this method should not be called
return nil, perrors.New("unexpected collection decode call")
}
Expand Down
14 changes: 4 additions & 10 deletions java_sql_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ type JavaSqlTimeSerializer struct{}
// nolint
func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error {
var (
i int
idx int
err error
clsDef *classInfo
clsDef *ClassInfo
className string
ptrV reflect.Value
)
Expand All @@ -78,13 +77,8 @@ func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error {
}

// write object definition
idx = -1
for i = range e.classInfoList {
if v.JavaClassName() == e.classInfoList[i].javaName {
idx = i
break
}
}
idx = e.classIndex(v.JavaClassName())

if idx == -1 {
idx, ok = checkPOJORegistry(vv)
if !ok {
Expand Down Expand Up @@ -114,7 +108,7 @@ func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error {
}

// nolint
func (JavaSqlTimeSerializer) DecObject(d *Decoder, typ reflect.Type, cls *classInfo) (interface{}, error) {
func (JavaSqlTimeSerializer) DecObject(d *Decoder, typ reflect.Type, cls *ClassInfo) (interface{}, error) {
if typ.Kind() != reflect.Struct {
return nil, perrors.Errorf("wrong type expect Struct but get:%s", typ.String())
}
Expand Down
2 changes: 1 addition & 1 deletion java_unknown_exception.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (

var exceptionCheckMutex sync.Mutex

func checkAndGetException(cls *classInfo) (*structInfo, bool) {
func checkAndGetException(cls *ClassInfo) (*structInfo, bool) {
if len(cls.fieldNameList) < 4 {
return nil, false
}
Expand Down
4 changes: 2 additions & 2 deletions java_unknown_exception_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func TestCheckAndGetException(t *testing.T) {
clazzInfo1 := &classInfo{
clazzInfo1 := &ClassInfo{
javaName: "com.test.UserDefinedException",
fieldNameList: []string{"detailMessage", "code", "suppressedExceptions", "stackTrace", "cause"},
}
Expand All @@ -36,7 +36,7 @@ func TestCheckAndGetException(t *testing.T) {
assert.Equal(t, s.javaName, "com.test.UserDefinedException")
assert.Equal(t, s.goName, "github.com/apache/dubbo-go-hessian2/hessian.UnknownException")

clazzInfo2 := &classInfo{
clazzInfo2 := &ClassInfo{
javaName: "com.test.UserDefinedException",
fieldNameList: []string{"detailMessage", "code", "suppressedExceptions", "cause"},
}
Expand Down
9 changes: 7 additions & 2 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ package hessian
import (
"io"
"reflect"
)

import (
perrors "github.com/pkg/errors"
)

Expand Down Expand Up @@ -110,6 +108,13 @@ func (e *Encoder) encMap(m interface{}) error {
return nil
}

// check whether it should encode the map as class.
if mm, ok := m.(map[string]interface{}); ok {
if _, ok = mm[ClassKey]; ok {
return e.EncodeMapClass(mm)
}
}

value = UnpackPtrValue(value)
// check nil map
if value.Kind() == reflect.Ptr && !value.Elem().IsValid() {
Expand Down
Loading