From 93282e101d3804c59716bc3b30f7d43221ee8c43 Mon Sep 17 00:00:00 2001 From: haiyizxx Date: Mon, 6 Jan 2025 03:26:46 -0500 Subject: [PATCH] refactor: improve edge case handling for recursion limits (#22988) Co-authored-by: Alex | Skip --- CHANGELOG.md | 2 ++ codec/types/interface_registry.go | 4 ++-- codec/unknownproto/unknown_fields.go | 2 +- x/tx/decode/unknown.go | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f0d68750ee8c..0816734bfcf0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,8 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i ### Improvements +* (codec) [#22988](https://github.com/cosmos/cosmos-sdk/pull/22988) Improve edge case handling for recursion limits. + ### Bug Fixes * (query) [23002](https://github.com/cosmos/cosmos-sdk/pull/23002) Fix collection filtered pagination. diff --git a/codec/types/interface_registry.go b/codec/types/interface_registry.go index 34d59bd33a46..68ed8c885d9f 100644 --- a/codec/types/interface_registry.go +++ b/codec/types/interface_registry.go @@ -274,10 +274,10 @@ func (r statefulUnpacker) cloneForRecursion() *statefulUnpacker { // UnpackAny deserializes a protobuf Any message into the provided interface, ensuring the interface is a pointer. // It applies stateful constraints such as max depth and call limits, and unpacks interfaces if required. func (r *statefulUnpacker) UnpackAny(any *Any, iface interface{}) error { - if r.maxDepth == 0 { + if r.maxDepth <= 0 { return errors.New("max depth exceeded") } - if r.maxCalls.count == 0 { + if r.maxCalls.count <= 0 { return errors.New("call limit exceeded") } // here we gracefully handle the case in which `any` itself is `nil`, which may occur in message decoding diff --git a/codec/unknownproto/unknown_fields.go b/codec/unknownproto/unknown_fields.go index 17b8f7e424ee..a60f2f9caac8 100644 --- a/codec/unknownproto/unknown_fields.go +++ b/codec/unknownproto/unknown_fields.go @@ -54,7 +54,7 @@ func doRejectUnknownFields( if len(bz) == 0 { return hasUnknownNonCriticals, nil } - if recursionLimit == 0 { + if recursionLimit <= 0 { return false, errors.New("recursion limit reached") } diff --git a/x/tx/decode/unknown.go b/x/tx/decode/unknown.go index 6d7b9616b2fb..ce608b32a4ba 100644 --- a/x/tx/decode/unknown.go +++ b/x/tx/decode/unknown.go @@ -47,7 +47,7 @@ func doRejectUnknownFields( if len(bz) == 0 { return hasUnknownNonCriticals, nil } - if recursionLimit == 0 { + if recursionLimit <= 0 { return false, errors.New("recursion limit reached") }