diff --git a/amt.go b/amt.go index 586a088..cc438ca 100644 --- a/amt.go +++ b/amt.go @@ -1,7 +1,6 @@ package amt import ( - "bytes" "context" "fmt" "math" @@ -149,11 +148,11 @@ func (r *Root) Set(ctx context.Context, i uint64, val cbg.CBORMarshaler) error { if val == nil { d.Raw = cbg.CborNull } else { - valueBuf := new(bytes.Buffer) - if err := val.MarshalCBOR(valueBuf); err != nil { + data, err := cborToBytes(val) + if err != nil { return err } - d.Raw = valueBuf.Bytes() + d.Raw = data } // where the index is greater than the number of elements we can fit into the diff --git a/util.go b/util.go index adf86a6..8d10265 100644 --- a/util.go +++ b/util.go @@ -1,6 +1,12 @@ package amt -import "math" +import ( + "bytes" + "math" + "sync" + + cbg "github.com/whyrusleeping/cbor-gen" +) // Given height 'height', how many nodes in a maximally full tree can we // build? (bitWidth^2)^height = width^height. If we pass in height+1 we can work @@ -13,3 +19,29 @@ func nodesForHeight(bitWidth uint, height int) uint64 { } return 1 << heightLogTwo } + +var bufferPool = sync.Pool{ + New: func() any { + return bytes.NewBuffer(nil) + }, +} + +func cborToBytes(val cbg.CBORMarshaler) ([]byte, error) { + // Temporary location to put values. We'll copy them to an exact-sized buffer when done. + valueBuf := bufferPool.Get().(*bytes.Buffer) + defer func() { + valueBuf.Reset() + bufferPool.Put(valueBuf) + }() + + if err := val.MarshalCBOR(valueBuf); err != nil { + return nil, err + } + + // Copy to shrink the allocation. + buf := valueBuf.Bytes() + cpy := make([]byte, len(buf)) + copy(cpy, buf) + + return cpy, nil +}