diff --git a/README.md b/README.md index 987027b5..b454844c 100644 --- a/README.md +++ b/README.md @@ -368,6 +368,20 @@ type Dog struct { } ``` +## Strict Mode + +Default, hessian2 will decode an object to map if it's not being registered. +If you don't want that, change the decoder to strict mode as following, +and it will return error when meeting unregistered object. + +```go +e := hessian.NewDecoder(bytes) +e.Strict = true // set to strict mode, default is false + +// or +e := hessian.NewStrictDecoder(bytes) +``` + ## Tools ### tools/gen-go-enum diff --git a/decode.go b/decode.go index 89af8532..d0b67a81 100644 --- a/decode.go +++ b/decode.go @@ -34,8 +34,23 @@ type Decoder struct { refs []interface{} // record type refs, both list and map need it typeRefs *TypeRefs - classInfoList []*classInfo + classInfoList []*ClassInfo isSkip bool + + // In strict mode, a class data can be decoded only when the class is registered, otherwise error returned. + // In non-strict mode, a class data will be decoded to a map when the class is not registered. + // The default is non-strict mode, user can change it as required. + Strict bool +} + +// FindClassInfo find ClassInfo for the given name in decoder class info list. +func (d *Decoder) FindClassInfo(javaName string) *ClassInfo { + for _, info := range d.classInfoList { + if info.javaName == javaName { + return info + } + } + return nil } // Error part @@ -49,6 +64,16 @@ func NewDecoder(b []byte) *Decoder { return &Decoder{reader: bufio.NewReader(bytes.NewReader(b)), typeRefs: &TypeRefs{records: map[string]bool{}}} } +// NewStrictDecoder generates a strict mode decoder instance. +// In strict mode, all target class must be registered. +func NewStrictDecoder(b []byte) *Decoder { + return &Decoder{ + reader: bufio.NewReader(bytes.NewReader(b)), + typeRefs: &TypeRefs{records: map[string]bool{}}, + Strict: true, + } +} + // NewDecoderSize generate a decoder instance. func NewDecoderSize(b []byte, size int) *Decoder { return &Decoder{reader: bufio.NewReaderSize(bytes.NewReader(b), size), typeRefs: &TypeRefs{records: map[string]bool{}}} diff --git a/decode_test.go b/decode_test.go index c3f2b0b1..99db40b6 100644 --- a/decode_test.go +++ b/decode_test.go @@ -145,6 +145,16 @@ func testDecodeFrameworkFunc(t *testing.T, method string, expected func(interfac expected(r) } +func mustDecodeObject(t *testing.T, b []byte) interface{} { + d := NewDecoder(b) + res, err := d.Decode() + if err != nil { + t.Error(err) + t.FailNow() + } + return res +} + func TestUserDefindeException(t *testing.T) { expect := &UnknownException{ DetailMessage: "throw UserDefindException", diff --git a/encode.go b/encode.go index 0daf3e05..802c2fb8 100644 --- a/encode.go +++ b/encode.go @@ -33,11 +33,21 @@ import ( // Encoder struct type Encoder struct { - classInfoList []*classInfo + classInfoList []*ClassInfo buffer []byte refMap map[unsafe.Pointer]_refElem } +// classIndex find the index of the given java name in encoder class info list. +func (e *Encoder) classIndex(javaName string) int { + for i := range e.classInfoList { + if javaName == e.classInfoList[i].javaName { + return i + } + } + return -1 +} + // NewEncoder generate an encoder instance func NewEncoder() *Encoder { buffer := make([]byte, 64) diff --git a/hessian.go b/hessian.go index dabac5a3..0549984b 100644 --- a/hessian.go +++ b/hessian.go @@ -74,7 +74,7 @@ func NewHessianCodec(reader *bufio.Reader) *HessianCodec { } } -// NewHessianCodec generate a new hessian codec instance +// NewHessianCodecCustom generate a new hessian codec instance. func NewHessianCodecCustom(pkgType PackageType, reader *bufio.Reader, bodyLen int) *HessianCodec { return &HessianCodec{ pkgType: pkgType, @@ -129,15 +129,15 @@ func (h *HessianCodec) ReadHeader(header *DubboHeader) error { return perrors.Errorf("serialization ID:%v", header.SerialID) } - flag := buf[2] & FLAG_EVENT - if flag != Zero { + headerFlag := buf[2] & FLAG_EVENT + if headerFlag != Zero { header.Type |= PackageHeartbeat } - flag = buf[2] & FLAG_REQUEST - if flag != Zero { + headerFlag = buf[2] & FLAG_REQUEST + if headerFlag != Zero { header.Type |= PackageRequest - flag = buf[2] & FLAG_TWOWAY - if flag != Zero { + headerFlag = buf[2] & FLAG_TWOWAY + if headerFlag != Zero { header.Type |= PackageRequest_TwoWay } } else { @@ -197,7 +197,7 @@ func (h *HessianCodec) ReadBody(rspObj interface{}) error { case PackageRequest | PackageHeartbeat, PackageResponse | PackageHeartbeat: case PackageRequest: if rspObj != nil { - if err = unpackRequestBody(NewDecoder(buf[:]), rspObj); err != nil { + if err = unpackRequestBody(NewStrictDecoder(buf[:]), rspObj); err != nil { return perrors.WithStack(err) } } @@ -212,7 +212,7 @@ func (h *HessianCodec) ReadBody(rspObj interface{}) error { return nil } -// ignore body, but only read attachments +// ReadAttachments ignore body, but only read attachments func (h *HessianCodec) ReadAttachments() (map[string]string, error) { if h.reader.Buffered() < h.bodyLen { return nil, ErrBodyNotEnough diff --git a/hessian_test.go b/hessian_test.go index a65c61fc..5a46b13b 100644 --- a/hessian_test.go +++ b/hessian_test.go @@ -114,39 +114,52 @@ func doTestResponse(t *testing.T, packageType PackageType, responseStatus byte, func TestResponse(t *testing.T) { caseObj := Case{A: "a", B: 1} decodedResponse := &Response{} + RegisterPOJO(&caseObj) arr := []*Case{&caseObj} - var arrRes []interface{} - decodedResponse.RspObj = &arrRes + decodedResponse.RspObj = nil doTestResponse(t, PackageResponse, Response_OK, arr, decodedResponse, func() { + arrRes, ok := decodedResponse.RspObj.([]*Case) + if !ok { + t.Errorf("expect []*Case, but get %s", reflect.TypeOf(decodedResponse.RspObj).String()) + return + } assert.Equal(t, 1, len(arrRes)) assert.Equal(t, &caseObj, arrRes[0]) }) - decodedResponse.RspObj = &Case{} - doTestResponse(t, PackageResponse, Response_OK, &Case{A: "a", B: 1}, decodedResponse, nil) + doTestResponse(t, PackageResponse, Response_OK, &caseObj, decodedResponse, func() { + assert.Equal(t, &caseObj, decodedResponse.RspObj) + }) s := "ok!!!!!" - strObj := "" - decodedResponse.RspObj = &strObj - doTestResponse(t, PackageResponse, Response_OK, s, decodedResponse, nil) + doTestResponse(t, PackageResponse, Response_OK, s, decodedResponse, func() { + assert.Equal(t, s, decodedResponse.RspObj) + }) - var intObj int64 - decodedResponse.RspObj = &intObj - doTestResponse(t, PackageResponse, Response_OK, int64(3), decodedResponse, nil) + doTestResponse(t, PackageResponse, Response_OK, int64(3), decodedResponse, func() { + assert.Equal(t, int64(3), decodedResponse.RspObj) + }) - boolObj := false - decodedResponse.RspObj = &boolObj - doTestResponse(t, PackageResponse, Response_OK, true, decodedResponse, nil) + doTestResponse(t, PackageResponse, Response_OK, true, decodedResponse, func() { + assert.Equal(t, true, decodedResponse.RspObj) + }) - strObj = "" - decodedResponse.RspObj = &strObj - doTestResponse(t, PackageResponse, Response_SERVER_ERROR, "error!!!!!", decodedResponse, nil) + errorMsg := "error!!!!!" + decodedResponse.RspObj = nil + doTestResponse(t, PackageResponse, Response_SERVER_ERROR, errorMsg, decodedResponse, func() { + assert.Equal(t, "java exception:error!!!!!", decodedResponse.Exception.Error()) + }) + decodedResponse.RspObj = nil + decodedResponse.Exception = nil mapObj := map[string][]*Case{"key": {&caseObj}} - mapRes := map[interface{}]interface{}{} - decodedResponse.RspObj = &mapRes doTestResponse(t, PackageResponse, Response_OK, mapObj, decodedResponse, func() { + mapRes, ok := decodedResponse.RspObj.(map[interface{}]interface{}) + if !ok { + t.Errorf("expect map[string][]*Case, but get %s", reflect.TypeOf(decodedResponse.RspObj).String()) + return + } c, ok := mapRes["key"] if !ok { assert.FailNow(t, "no key in decoded response map") @@ -211,7 +224,8 @@ func TestHessianCodec_ReadAttachments(t *testing.T) { t.Log(h) err = codecR1.ReadBody(body) - assert.Equal(t, "can not find go type name com.test.caseb in registry", err.Error()) + assert.NoError(t, err) + // assert.Equal(t, "can not find go type name com.test.caseb in registry", err.Error()) attrs, err := codecR2.ReadAttachments() assert.NoError(t, err) assert.Equal(t, "2.6.4", attrs[DUBBO_VERSION_KEY]) diff --git a/int.go b/int.go index 06cdb719..1c5fd7b3 100644 --- a/int.go +++ b/int.go @@ -110,7 +110,7 @@ func (d *Decoder) decInt32(flag int32) (int32, error) { } } -func (d *Encoder) encTypeInt32(b []byte, p interface{}) ([]byte, error) { +func (e *Encoder) encTypeInt32(b []byte, p interface{}) ([]byte, error) { value := reflect.ValueOf(p) if PackPtr(value).IsNil() { return EncNull(b), nil diff --git a/java_collection.go b/java_collection.go index f9c0ce53..ffab2860 100644 --- a/java_collection.go +++ b/java_collection.go @@ -85,7 +85,7 @@ func (JavaCollectionSerializer) EncObject(e *Encoder, vv POJO) error { return nil } -func (JavaCollectionSerializer) DecObject(d *Decoder, typ reflect.Type, cls *classInfo) (interface{}, error) { +func (JavaCollectionSerializer) DecObject(d *Decoder, typ reflect.Type, cls *ClassInfo) (interface{}, error) { // for the java impl of hessian encode collections as list, which will not be decoded as object in go impl, this method should not be called return nil, perrors.New("unexpected collection decode call") } diff --git a/java_sql_time.go b/java_sql_time.go index 8edee13d..b86cc4c8 100644 --- a/java_sql_time.go +++ b/java_sql_time.go @@ -49,10 +49,9 @@ type JavaSqlTimeSerializer struct{} // nolint func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error { var ( - i int idx int err error - clsDef *classInfo + clsDef *ClassInfo className string ptrV reflect.Value ) @@ -78,13 +77,8 @@ func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error { } // write object definition - idx = -1 - for i = range e.classInfoList { - if v.JavaClassName() == e.classInfoList[i].javaName { - idx = i - break - } - } + idx = e.classIndex(v.JavaClassName()) + if idx == -1 { idx, ok = checkPOJORegistry(vv) if !ok { @@ -114,7 +108,7 @@ func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error { } // nolint -func (JavaSqlTimeSerializer) DecObject(d *Decoder, typ reflect.Type, cls *classInfo) (interface{}, error) { +func (JavaSqlTimeSerializer) DecObject(d *Decoder, typ reflect.Type, cls *ClassInfo) (interface{}, error) { if typ.Kind() != reflect.Struct { return nil, perrors.Errorf("wrong type expect Struct but get:%s", typ.String()) } diff --git a/java_unknown_exception.go b/java_unknown_exception.go index 035fb941..0bbcce80 100644 --- a/java_unknown_exception.go +++ b/java_unknown_exception.go @@ -28,7 +28,7 @@ import ( var exceptionCheckMutex sync.Mutex -func checkAndGetException(cls *classInfo) (*structInfo, bool) { +func checkAndGetException(cls *ClassInfo) (*structInfo, bool) { if len(cls.fieldNameList) < 4 { return nil, false } diff --git a/java_unknown_exception_test.go b/java_unknown_exception_test.go index 264beafd..4fe9493c 100644 --- a/java_unknown_exception_test.go +++ b/java_unknown_exception_test.go @@ -26,7 +26,7 @@ import ( ) func TestCheckAndGetException(t *testing.T) { - clazzInfo1 := &classInfo{ + clazzInfo1 := &ClassInfo{ javaName: "com.test.UserDefinedException", fieldNameList: []string{"detailMessage", "code", "suppressedExceptions", "stackTrace", "cause"}, } @@ -36,7 +36,7 @@ func TestCheckAndGetException(t *testing.T) { assert.Equal(t, s.javaName, "com.test.UserDefinedException") assert.Equal(t, s.goName, "github.com/apache/dubbo-go-hessian2/hessian.UnknownException") - clazzInfo2 := &classInfo{ + clazzInfo2 := &ClassInfo{ javaName: "com.test.UserDefinedException", fieldNameList: []string{"detailMessage", "code", "suppressedExceptions", "cause"}, } diff --git a/map.go b/map.go index fd0b95ca..4f144b83 100644 --- a/map.go +++ b/map.go @@ -20,9 +20,7 @@ package hessian import ( "io" "reflect" -) -import ( perrors "github.com/pkg/errors" ) @@ -110,6 +108,13 @@ func (e *Encoder) encMap(m interface{}) error { return nil } + // check whether it should encode the map as class. + if mm, ok := m.(map[string]interface{}); ok { + if _, ok = mm[ClassKey]; ok { + return e.EncodeMapClass(mm) + } + } + value = UnpackPtrValue(value) // check nil map if value.Kind() == reflect.Ptr && !value.Elem().IsValid() { diff --git a/object.go b/object.go index 27dd5309..54526ae4 100644 --- a/object.go +++ b/object.go @@ -103,7 +103,7 @@ func (e *Encoder) encObject(v interface{}) error { idx int num int err error - clsDef *classInfo + clsDef *ClassInfo ) pojo, isPojo := v.(POJO) // get none pojo JavaClassName @@ -209,6 +209,82 @@ func (e *Encoder) encObject(v interface{}) error { return nil } +// EncodeMapClass encode a map as object, which MUST contains a key _class and its value is the target class name. +func (e *Encoder) EncodeMapClass(m map[string]interface{}) error { + clsName, ok := m[ClassKey] + if !ok { + return perrors.New("no _class key map") + } + + className, ok := clsName.(string) + if !ok { + return perrors.Errorf("expect string class name, but get %v", reflect.TypeOf(clsName)) + } + + return e.EncodeMapAsClass(className, m) +} + +// EncodeMapAsClass encode a map as object of given class name. +func (e *Encoder) EncodeMapAsClass(className string, m map[string]interface{}) error { + idx := e.classIndex(className) + + if idx == -1 { + var clsDef *ClassInfo + s, ok := getStructInfo(className) + if ok { + clsDef = pojoRegistry.classInfoList[s.index] + } else { + var err error + clsDef, err = buildMapClassDef(className, m) + if err != nil { + return err + } + } + idx = len(e.classInfoList) + e.classInfoList = append(e.classInfoList, clsDef) + e.buffer = append(e.buffer, clsDef.buffer...) + } + + return e.encodeMapAsIndexedClass(idx, m) +} + +// EncodeMapAsObject encode a map as the given class defined object. +// Sometimes a class may not being registered in hessian, but it can be decoded from serialized data, +// and the ClassInfo can be found in Decoder by calling Decoder.FindClassInfo. +func (e *Encoder) EncodeMapAsObject(clsDef *ClassInfo, m map[string]interface{}) error { + idx := e.classIndex(clsDef.javaName) + if idx == -1 { + idx = len(e.classInfoList) + e.classInfoList = append(e.classInfoList, clsDef) + if len(clsDef.buffer) == 0 { + clsDef.initDefBuffer() + } + e.buffer = append(e.buffer, clsDef.buffer...) + } + return e.encodeMapAsIndexedClass(idx, m) +} + +// encodeMapAsIndexedClass encode a map as the defined class at the given index in the encoder class list. +func (e *Encoder) encodeMapAsIndexedClass(idx int, m map[string]interface{}) error { + // write object instance + if byte(idx) <= OBJECT_DIRECT_MAX { + e.buffer = encByte(e.buffer, byte(idx)+BC_OBJECT_DIRECT) + } else { + e.buffer = encByte(e.buffer, BC_OBJECT) + e.buffer = encInt32(e.buffer, int32(idx)) + } + + cls := e.classInfoList[idx] + var err error + for i := 0; i < len(cls.fieldNameList); i++ { + fieldName := cls.fieldNameList[i] + if err = e.Encode(m[fieldName]); err != nil { + return perrors.Wrapf(err, "failed to encode field: %s, %+v", fieldName, m[fieldName]) + } + } + return nil +} + ///////////////////////////////////////// // Object ///////////////////////////////////////// @@ -304,7 +380,7 @@ func (d *Decoder) decClassDef() (interface{}, error) { fieldList[i] = fieldName } - return &classInfo{javaName: clsName, fieldNameList: fieldList}, nil + return &ClassInfo{javaName: clsName, fieldNameList: fieldList}, nil } type fieldInfo struct { @@ -374,7 +450,7 @@ func findField(name string, typ reflect.Type) ([]int, *reflect.StructField, erro return []int{}, nil, perrors.Errorf("failed to find field %s", name) } -func (d *Decoder) decInstance(typ reflect.Type, cls *classInfo) (interface{}, error) { +func (d *Decoder) decInstance(typ reflect.Type, cls *ClassInfo) (interface{}, error) { if typ.Kind() != reflect.Struct { return nil, perrors.Errorf("wrong type expect Struct but get:%s", typ.String()) } @@ -553,14 +629,14 @@ func (d *Decoder) decInstance(typ reflect.Type, cls *classInfo) (interface{}, er return vRef.Interface(), nil } -func (d *Decoder) appendClsDef(cd *classInfo) { +func (d *Decoder) appendClsDef(cd *ClassInfo) { d.classInfoList = append(d.classInfoList, cd) } -func (d *Decoder) getStructDefByIndex(idx int) (reflect.Type, *classInfo, error) { +func (d *Decoder) getStructDefByIndex(idx int) (reflect.Type, *ClassInfo, error) { var ( ok bool - cls *classInfo + cls *ClassInfo s *structInfo err error ) @@ -575,7 +651,7 @@ func (d *Decoder) getStructDefByIndex(idx int) (reflect.Type, *classInfo, error) if s, ok = checkAndGetException(cls); ok { return s.typ, cls, nil } - if !d.isSkip { + if !d.isSkip && d.Strict { err = perrors.Errorf("can not find go type name %s in registry", cls.javaName) } return nil, cls, err @@ -607,13 +683,13 @@ func (d *Decoder) decEnum(javaName string, flag int32) (JavaEnum, error) { } // skip this object -func (d *Decoder) skip(cls *classInfo) error { - len := len(cls.fieldNameList) - if len < 1 { +func (d *Decoder) skip(cls *ClassInfo) error { + fieldLen := len(cls.fieldNameList) + if fieldLen < 1 { return nil } - for i := 0; i < len; i++ { + for i := 0; i < fieldLen; i++ { // skip class fields. if _, err := d.DecodeValue(); err != nil { return err @@ -629,7 +705,7 @@ func (d *Decoder) decObject(flag int32) (interface{}, error) { idx int32 err error typ reflect.Type - cls *classInfo + cls *ClassInfo ) if flag != TAG_READ { @@ -648,7 +724,7 @@ func (d *Decoder) decObject(flag int32) (interface{}, error) { if decErr != nil { return nil, perrors.Wrap(decErr, "decObject->decClassDef byte double") } - cls, _ = clsDef.(*classInfo) + cls, _ = clsDef.(*ClassInfo) // add to slice d.appendClsDef(cls) @@ -665,7 +741,10 @@ func (d *Decoder) decObject(flag int32) (interface{}, error) { return nil, err } if typ == nil { - return nil, d.skip(cls) + if d.isSkip { + return nil, d.skip(cls) + } + return d.decClassToMap(cls) } if typ.Implements(javaEnumType) { return d.decEnum(cls.javaName, TAG_READ) @@ -683,7 +762,10 @@ func (d *Decoder) decObject(flag int32) (interface{}, error) { return nil, err } if typ == nil { - return nil, d.skip(cls) + if d.isSkip { + return nil, d.skip(cls) + } + return d.decClassToMap(cls) } if typ.Implements(javaEnumType) { return d.decEnum(cls.javaName, TAG_READ) @@ -699,3 +781,22 @@ func (d *Decoder) decObject(flag int32) (interface{}, error) { return nil, perrors.Errorf("decObject illegal object type tag:%+v", tag) } } + +func (d *Decoder) decClassToMap(cls *ClassInfo) (interface{}, error) { + vMap := make(map[string]interface{}, len(cls.fieldNameList)) + vMap[ClassKey] = cls.javaName + + d.appendRefs(vMap) + + for i := 0; i < len(cls.fieldNameList); i++ { + fieldName := cls.fieldNameList[i] + + fieldValue, decErr := d.DecodeValue() + if decErr != nil { + return nil, perrors.Wrapf(decErr, "decClassToMap -> decode field name:%s", fieldName) + } + vMap[fieldName] = EnsureRawAny(fieldValue) + } + + return vMap, nil +} diff --git a/object_test.go b/object_test.go index d6d67b19..b7afa118 100644 --- a/object_test.go +++ b/object_test.go @@ -23,9 +23,7 @@ import ( "reflect" "testing" "time" -) -import ( "github.com/stretchr/testify/assert" ) @@ -224,6 +222,153 @@ func TestIssue6(t *testing.T) { } } +func TestDecClassToMap(t *testing.T) { + name := UserName{ + FirstName: "John", + LastName: "Doe", + } + person := Person{ + UserName: name, + Age: 18, + Sex: true, + } + + worker1 := &Worker{ + Person: person, + CurJob: JOB{Title: "cto", Company: "facebook"}, + Jobs: []JOB{ + {Title: "manager", Company: "google"}, + {Title: "ceo", Company: "microsoft"}, + }, + } + + t.Logf("worker1: %v", worker1) + + e := NewEncoder() + encErr := e.Encode(worker1) + if encErr != nil { + t.Fatalf("encode(worker:%#v) = error:%s", worker1, encErr) + } + data := e.Buffer() + t.Logf("data: %s", data) + + // unRegisterPOJO before decode, so that to decode to map + unRegisterPOJO(name) + unRegisterPOJO(person) + unRegisterPOJO(worker1) + unRegisterPOJO(&worker1.Jobs[0]) + + // strict mode + d := NewDecoder(data) + d.Strict = true + res, err := d.Decode() + if err == nil { + t.Error("after unregister pojo, decoding should return error for strict mode") + t.FailNow() + } + assert.Nil(t, res) + + // non-strict mode + d = NewDecoder(data) + res, err = d.Decode() + if err != nil { + t.Error(err) + t.FailNow() + } + t.Logf("type of decode object:%v", reflect.TypeOf(res)) + + worker2, ok := res.(map[string]interface{}) + if !ok { + t.Fatalf("res:%#v should be a map for non-strict mode", res) + } + + t.Logf("worker2: %v", worker2) + + // register pojo again + RegisterPOJO(name) + RegisterPOJO(person) + RegisterPOJO(worker1) + RegisterPOJO(&worker1.Jobs[0]) + + // encode the map to object again + e = NewEncoder() + err = e.Encode(worker2) + if err != nil { + t.Error(err) + t.FailNow() + } + + data = e.Buffer() + t.Logf("data: %s", data) + + // decode the encoded map data to struct + d = NewDecoder(data) + res, err = d.Decode() + if err != nil { + t.Error(err) + t.FailNow() + } + t.Logf("type of decode object:%v", reflect.TypeOf(res)) + + worker3, ok := res.(*Worker) + if !ok { + t.Fatalf("res:%#v should be a worker type", res) + } + + t.Logf("worker3: %v", worker3) + if !reflect.DeepEqual(worker1, worker3) { + t.Fatal("worker1 not equal to worker3!") + } +} + +func TestEncodeMapToObject(t *testing.T) { + name := &UserName{ + FirstName: "John", + LastName: "Doe", + } + + // note: the first letter of the keys MUST lowercase. + m := map[string]interface{}{ + "firstName": "John", + "lastName": "Doe", + } + + e := NewEncoder() + encErr := e.EncodeMapAsClass(name.JavaClassName(), m) + if encErr != nil { + t.Error(encErr) + t.FailNow() + } + + // register for decode map to object. + RegisterPOJO(name) + res := mustDecodeObject(t, e.Buffer()) + assert.True(t, reflect.DeepEqual(name, res)) + + // unregister for encode map again. + UnRegisterPOJOs(name) + + // note: the map contains the class key. + m = map[string]interface{}{ + ClassKey: name.JavaClassName(), + "firstName": "John", + "lastName": "Doe", + } + + // try to encode again + e = NewEncoder() + encErr = e.EncodeMapClass(m) + if encErr != nil { + t.Error(encErr) + t.FailNow() + } + + // register for decode map to object. + RegisterPOJO(name) + res = mustDecodeObject(t, e.Buffer()) + assert.True(t, reflect.DeepEqual(name, res)) +} + type A0 struct{} // JavaClassName java fully qualified path diff --git a/pojo.go b/pojo.go index 273ef7c7..c74394f6 100644 --- a/pojo.go +++ b/pojo.go @@ -32,6 +32,8 @@ import ( // invalid consts const ( InvalidJavaEnum JavaEnum = -1 + + ClassKey = "_class" ) // struct filed tag of hessian @@ -72,7 +74,7 @@ type JavaEnumClass struct { name string } -type classInfo struct { +type ClassInfo struct { javaName string fieldNameList []string buffer []byte // encoded buffer @@ -89,7 +91,7 @@ type structInfo struct { // POJORegistry pojo registry struct type POJORegistry struct { sync.RWMutex - classInfoList []*classInfo // {class name, field name list...} list + classInfoList []*ClassInfo // {class name, field name list...} list j2g map[string]string // java class name --> go struct name registry map[string]*structInfo // go class name --> go struct info } @@ -103,6 +105,19 @@ var ( javaEnumType = reflect.TypeOf((*POJOEnum)(nil)).Elem() ) +// initDefBuffer initial the class definition buffer, which can be used repeatedly. +func (c *ClassInfo) initDefBuffer() { + if len(c.buffer) == 0 { + c.buffer = encByte(c.buffer, BC_OBJECT_DEF) + c.buffer = encString(c.buffer, c.javaName) + c.buffer = encInt32(c.buffer, int32(len(c.fieldNameList))) + + for _, fieldName := range c.fieldNameList { + c.buffer = encString(c.buffer, fieldName) + } + } +} + // struct parsing func showPOJORegistry() { pojoRegistry.Lock() @@ -140,7 +155,7 @@ func RegisterPOJOMapping(javaClassName string, o interface{}) int { bBody []byte fieldList []string sttInfo structInfo - clsDef classInfo + clsDef ClassInfo ) sttInfo.typ = obtainValueType(o) @@ -199,9 +214,9 @@ func RegisterPOJOMapping(javaClassName string, o interface{}) int { bHeader = encInt32(bHeader, int32(len(fieldList))) // prepare classDef - clsDef = classInfo{javaName: sttInfo.javaName, fieldNameList: fieldList} + clsDef = ClassInfo{javaName: sttInfo.javaName, fieldNameList: fieldList} - // merge header and body of objectDef into buffer of classInfo + // merge header and body of objectDef into buffer of ClassInfo clsDef.buffer = append(bHeader, bBody...) sttInfo.index = len(pojoRegistry.classInfoList) @@ -227,14 +242,14 @@ func unRegisterPOJO(o POJO) int { goName := GetGoType(o) - if structInfo, ok := pojoRegistry.registry[goName]; ok { - delete(pojoRegistry.j2g, structInfo.javaName) - listTypeNameMapper.Delete(structInfo.goName) + if pojoStructInfo, ok := pojoRegistry.registry[goName]; ok { + delete(pojoRegistry.j2g, pojoStructInfo.javaName) + listTypeNameMapper.Delete(pojoStructInfo.goName) // remove registry cache. - delete(pojoRegistry.registry, structInfo.goName) + delete(pojoRegistry.registry, pojoStructInfo.goName) // don't remove registry classInfoList, // indexes of registered pojo may be affected. - return structInfo.index + return pojoStructInfo.index } return -1 @@ -300,7 +315,7 @@ func RegisterJavaEnum(o POJOEnum) int { f string l []string t structInfo - c classInfo + c ClassInfo v reflect.Value ) @@ -331,7 +346,7 @@ func RegisterJavaEnum(o POJOEnum) int { l = append(l, f) b = encString(b, f) - c = classInfo{javaName: t.javaName, fieldNameList: l} + c = ClassInfo{javaName: t.javaName, fieldNameList: l} c.buffer = append(c.buffer, b[:]...) t.index = len(pojoRegistry.classInfoList) pojoRegistry.classInfoList = append(pojoRegistry.classInfoList, &c) @@ -369,27 +384,22 @@ func loadPOJORegistry(v interface{}) (*structInfo, bool) { // @typeName is class's java name func getStructInfo(javaName string) (*structInfo, bool) { - var ( - ok bool - g string - s *structInfo - ) - pojoRegistry.RLock() - g, ok = pojoRegistry.j2g[javaName] - if ok { - s, ok = pojoRegistry.registry[g] + defer pojoRegistry.RUnlock() + + if g, ok := pojoRegistry.j2g[javaName]; ok { + s, b := pojoRegistry.registry[g] + return s, b } - pojoRegistry.RUnlock() - return s, ok + return nil, false } -func getStructDefByIndex(idx int) (reflect.Type, *classInfo, error) { +func getStructDefByIndex(idx int) (reflect.Type, *ClassInfo, error) { var ( ok bool clsName string - cls *classInfo + cls *ClassInfo s *structInfo ) @@ -439,3 +449,30 @@ func lowerCamelCase(s string) string { runes[0] = unicode.ToLower(runes[0]) return string(runes) } + +// buildMapClassDef build ClassInfo from map keys. +func buildMapClassDef(javaName string, m map[string]interface{}) (*ClassInfo, error) { + if javaName == "" { + var ok bool + javaName, ok = m[ClassKey].(string) + if !ok { + return nil, perrors.Errorf("no java name to build class info from map: %v", m) + } + } + + info := &ClassInfo{javaName: javaName} + + _, existClassKey := m[ClassKey] + + for fieldName := range m { + if existClassKey && fieldName == ClassKey { + continue + } + + info.fieldNameList = append(info.fieldNameList, fieldName) + } + + info.initDefBuffer() + + return info, nil +} diff --git a/response.go b/response.go index 2202c865..66235d79 100644 --- a/response.go +++ b/response.go @@ -23,14 +23,9 @@ import ( "reflect" "strconv" "strings" -) - -import ( - perrors "github.com/pkg/errors" -) -import ( "github.com/apache/dubbo-go-hessian2/java_exception" + perrors "github.com/pkg/errors" ) // Response dubbo response @@ -163,14 +158,14 @@ func unpackResponseBody(decoder *Decoder, resp interface{}) error { switch rspType { case RESPONSE_WITH_EXCEPTION, RESPONSE_WITH_EXCEPTION_WITH_ATTACHMENTS: - expt, err := decoder.Decode() - if err != nil { - return perrors.WithStack(err) + expt, decErr := decoder.Decode() + if decErr != nil { + return perrors.WithStack(decErr) } if rspType == RESPONSE_WITH_EXCEPTION_WITH_ATTACHMENTS { - attachments, err := decoder.Decode() - if err != nil { - return perrors.WithStack(err) + attachments, attErr := decoder.Decode() + if attErr != nil { + return perrors.WithStack(attErr) } if v, ok := attachments.(map[interface{}]interface{}); ok { atta := ToMapStringString(v) @@ -188,14 +183,14 @@ func unpackResponseBody(decoder *Decoder, resp interface{}) error { return nil case RESPONSE_VALUE, RESPONSE_VALUE_WITH_ATTACHMENTS: - rsp, err := decoder.Decode() - if err != nil { - return perrors.WithStack(err) + rsp, decErr := decoder.Decode() + if decErr != nil { + return perrors.WithStack(decErr) } if rspType == RESPONSE_VALUE_WITH_ATTACHMENTS { - attachments, err := decoder.Decode() - if err != nil { - return perrors.WithStack(err) + attachments, attErr := decoder.Decode() + if attErr != nil { + return perrors.WithStack(attErr) } if v, ok := attachments.(map[interface{}]interface{}); ok { response.Attachments = ToMapStringString(v) @@ -204,19 +199,15 @@ func unpackResponseBody(decoder *Decoder, resp interface{}) error { } } - // If the return value is nil, - // we should consider it normal - if rsp == nil { - return nil - } + response.RspObj = rsp - return perrors.WithStack(ReflectResponse(rsp, response.RspObj)) + return nil case RESPONSE_NULL_VALUE, RESPONSE_NULL_VALUE_WITH_ATTACHMENTS: if rspType == RESPONSE_NULL_VALUE_WITH_ATTACHMENTS { - attachments, err := decoder.Decode() - if err != nil { - return perrors.WithStack(err) + attachments, decErr := decoder.Decode() + if decErr != nil { + return perrors.WithStack(decErr) } if v, ok := attachments.(map[interface{}]interface{}); ok { atta := ToMapStringString(v) diff --git a/serialize.go b/serialize.go index ca738ea3..bf829da5 100644 --- a/serialize.go +++ b/serialize.go @@ -19,9 +19,7 @@ package hessian import ( "reflect" -) -import ( big "github.com/dubbogo/gost/math/big" ) @@ -37,7 +35,7 @@ func init() { type Serializer interface { EncObject(*Encoder, POJO) error - DecObject(*Decoder, reflect.Type, *classInfo) (interface{}, error) + DecObject(*Decoder, reflect.Type, *ClassInfo) (interface{}, error) } var serializerMap = make(map[string]Serializer, 16) @@ -53,7 +51,7 @@ func GetSerializer(javaClassName string) (Serializer, bool) { type IntegerSerializer struct{} -func (IntegerSerializer) DecObject(d *Decoder, typ reflect.Type, cls *classInfo) (interface{}, error) { +func (IntegerSerializer) DecObject(d *Decoder, typ reflect.Type, cls *ClassInfo) (interface{}, error) { bigInt, err := d.decInstance(typ, cls) if err != nil { return nil, err @@ -88,7 +86,7 @@ func (DecimalSerializer) EncObject(e *Encoder, v POJO) error { return e.encObject(decimal) } -func (DecimalSerializer) DecObject(d *Decoder, typ reflect.Type, cls *classInfo) (interface{}, error) { +func (DecimalSerializer) DecObject(d *Decoder, typ reflect.Type, cls *ClassInfo) (interface{}, error) { dec, err := d.decInstance(typ, cls) if err != nil { return nil, err