Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve log poller codec usage and optimise log event discriminator matching #1014

Merged
merged 8 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions pkg/solana/codec/codec_entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ func NewAccountEntry(offchainName string, idlTypes AccountIDLTypes, includeDiscr
return nil, err
}

var discriminator *Discriminator
if includeDiscriminator {
discriminator = NewDiscriminator(idlTypes.Account.Name, true)
}

return newEntry(
offchainName,
idlTypes.Account.Name,
accCodec,
includeDiscriminator,
discriminator,
mod,
), nil
}
Expand All @@ -64,7 +69,7 @@ func NewPDAEntry(offchainName string, pdaTypeDef PDATypeDef, mod codec.Modifier,
offchainName,
offchainName, // PDA seeds do not correlate to anything on-chain so reusing offchain name
accCodec,
false,
nil,
mod,
), nil
}
Expand All @@ -85,7 +90,7 @@ func NewInstructionArgsEntry(offChainName string, idlTypes InstructionArgsIDLTyp
idlTypes.Instruction.Name,
instructionCodecArgs,
// Instruction arguments don't need a discriminator by default
false,
nil,
mod,
), nil
}
Expand All @@ -101,30 +106,40 @@ func NewEventArgsEntry(offChainName string, idlTypes EventIDLTypes, includeDiscr
return nil, err
}

var discriminator *Discriminator
if includeDiscriminator {
discriminator = NewDiscriminator(idlTypes.Event.Name, false)
}

return newEntry(
offChainName,
idlTypes.Event.Name,
eventCodec,
includeDiscriminator,
discriminator,
mod,
), nil
}

func newEntry(
genericName, chainSpecificName string,
typeCodec commonencodings.TypeCodec,
includeDiscriminator bool,
discriminator *Discriminator,
mod codec.Modifier,
) Entry {
return &entry{
genericName: genericName,
chainSpecificName: chainSpecificName,
reflectType: typeCodec.GetType(),
typeCodec: typeCodec,
mod: ensureModifier(mod),
includeDiscriminator: includeDiscriminator,
discriminator: *NewDiscriminator(chainSpecificName),
e := &entry{
genericName: genericName,
chainSpecificName: chainSpecificName,
reflectType: typeCodec.GetType(),
typeCodec: typeCodec,
mod: ensureModifier(mod),
}

if discriminator != nil {
e.discriminator = *discriminator
e.includeDiscriminator = true
}

return e
}

func createRefs(idlTypes IdlTypeDefSlice, builder commonencodings.Builder) *codecRefs {
Expand Down Expand Up @@ -175,8 +190,8 @@ func (e *entry) Decode(encoded []byte) (any, []byte, error) {
}

if !bytes.Equal(e.discriminator.hashPrefix, encoded[:discriminatorLength]) {
return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v for genericName: %q, chainSpecificName: %q",
commontypes.ErrInvalidType, encoded[:discriminatorLength], e.genericName, e.chainSpecificName)
return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v, expected %v, for genericName: %q, chainSpecificName: %q",
commontypes.ErrInvalidType, encoded[:discriminatorLength], e.discriminator.hashPrefix, e.genericName, e.chainSpecificName)
}

encoded = encoded[discriminatorLength:]
Expand Down
20 changes: 14 additions & 6 deletions pkg/solana/codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,26 @@ func (it *codecInterfaceTester) GetAccountString(i int) string {
}

func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte {
if request.TestOn == TestItemType || request.TestOn == testutils.TestEventItem {
return encodeFieldsOnItem(t, request)
if request.TestOn == TestItemType {
return encodeFieldsOnItem(t, request, true)
} else if request.TestOn == testutils.TestEventItem {
return encodeFieldsOnItem(t, request, false)
}

return encodeFieldsOnSliceOrArray(t, request)
}

func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report {
func encodeFieldsOnItem(t *testing.T, request *EncodeRequest, isAccount bool) ocr2types.Report {
buf := new(bytes.Buffer)
// The underlying TestItemAsAccount adds a discriminator by default while being Borsh encoded.
if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
// The underlying TestItem adds a discriminator by default while being Borsh encoded.
if isAccount {
if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
}
} else {
if err := testutils.EncodeRequestToTestItemAsEvent(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
}
}
return buf.Bytes()
}
Expand Down
60 changes: 57 additions & 3 deletions pkg/solana/codec/discriminator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@ import (

const discriminatorLength = 8

func NewDiscriminator(name string) *Discriminator {
sum := sha256.Sum256([]byte("account:" + name))
return &Discriminator{hashPrefix: sum[:discriminatorLength]}
func NewDiscriminator(name string, isAccount bool) *Discriminator {
return &Discriminator{hashPrefix: NewDiscriminatorHashPrefix(name, isAccount)}
}

// NewDiscriminatorHashPrefix returns the first discriminatorLength bytes of the SHA-256
// digest of "<namespace>:<name>". The namespace is "account" when isAccount is true and
// "event" otherwise.
func NewDiscriminatorHashPrefix(name string, isAccount bool) []byte {
	namespace := "event:"
	if isAccount {
		namespace = "account:"
	}

	digest := sha256.Sum256([]byte(namespace + name))
	return digest[:discriminatorLength]
}

type Discriminator struct {
Expand Down Expand Up @@ -69,3 +79,47 @@ func (d Discriminator) Size(_ int) (int, error) {
func (d Discriminator) FixedSize() (int, error) {
return discriminatorLength, nil
}

// DiscriminatorExtractor decodes the leading discriminator bytes of a base64 encoded
// string using a precomputed lookup table. Build it with NewDiscriminatorExtractor; the
// zero value has an all-zero table and therefore decodes every character as 'A'.
type DiscriminatorExtractor struct {
// b64Index maps an ASCII byte value to its position in the standard base64 alphabet.
// Entries for bytes that are not part of the alphabet stay at 0.
b64Index [128]byte
}

// NewDiscriminatorExtractor returns an extractor whose lookup table maps every character
// of the standard base64 alphabet to its 6-bit value. All other table entries remain 0.
// The extractor exists to pull discriminators out of base64 strings without the cost of
// a full encoding/base64 decode.
func NewDiscriminatorExtractor() DiscriminatorExtractor {
	const base64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

	var extractor DiscriminatorExtractor
	// The alphabet is pure ASCII, so the range index equals the character's 6-bit value.
	for value, char := range base64Chars {
		extractor.b64Index[char] = byte(value)
	}
	return extractor
}

// Extract decodes the first 12 characters of a standard base64 encoded string and returns
// the first 8 decoded bytes, which correspond to a Solana discriminator. It avoids
// encoding/base64 for speed (reported as around 40% faster than the standard library).
//
// data must be at least 12 bytes long; shorter input panics when the leading 12 bytes are
// sliced off. Padding is not interpreted.
//
// No validation is performed. Any byte outside the base64 alphabet silently produces an
// inaccurate result: ASCII bytes such as '!', '@' or space decode as index 0 (the same as
// 'A'), and bytes >= 0x80 are first folded into the ASCII range so that they cannot index
// past the 128-entry lookup table and panic.
func (e *DiscriminatorExtractor) Extract(data string) []byte {
	// asciiMask keeps every lookup index within 0..127, the bounds of b64Index.
	const asciiMask = 0x7f

	var decodeBuffer [9]byte
	d := decodeBuffer[:9]
	s := data[:12] // panics if data is shorter than 12 bytes

	// Each pass turns 4 base64 characters (6 bits each) into 3 raw bytes.
	for i := 0; i < 3; i++ {
		// decode base64 chars into their 6-bit values
		c1 := e.b64Index[s[0]&asciiMask]
		c2 := e.b64Index[s[1]&asciiMask]
		c3 := e.b64Index[s[2]&asciiMask]
		c4 := e.b64Index[s[3]&asciiMask]

		// reconstruct raw bytes; byte arithmetic discards the bits shifted out on the left
		d[0] = (c1 << 2) | (c2 >> 4)
		d[1] = (c2 << 4) | (c3 >> 2)
		d[2] = (c3 << 6) | c4

		// advance to the next 3 output bytes and the next 4 input characters
		d = d[3:]
		s = s[4:]
	}

	// 9 bytes were decoded; only the first discriminatorLength are part of the discriminator.
	return decodeBuffer[:discriminatorLength]
}
104 changes: 104 additions & 0 deletions pkg/solana/codec/discriminator_extractor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package codec

import (
"encoding/base64"
mathrand "math/rand"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

// FuzzExtractorHappyPath verifies that, for every fuzz input the standard library accepts
// as base64, Extract yields the same first 8 bytes as base64.StdEncoding.
func FuzzExtractorHappyPath(f *testing.F) {
	// Seed with valid base64 discriminators.
	for _, seed := range []string{
		"SGVsbG8gV29ybGQh", // Hello World!
		"AAAAAAAAAAAA",     // Zero bytes
		"////////////",     // Max value bytes
		"QUJDREVGR0hJSktM", // ABCDEFGHIJKL
	} {
		f.Add(seed)
	}

	extractor := NewDiscriminatorExtractor()
	f.Fuzz(func(t *testing.T, testString string) {
		// Extract needs 12 characters and doesn't validate padding, newlines, or tabs,
		// so such inputs are out of scope for the comparison.
		tooShort := len(testString) < 12
		hasWhitespace := strings.ContainsAny(testString, "\n\r\t")
		hasPadding := strings.HasSuffix(testString, "=")
		if tooShort || hasWhitespace || hasPadding {
			return
		}

		want, err := base64.StdEncoding.DecodeString(testString)
		if err != nil {
			// Only inputs the standard decoder accepts are compared.
			return
		}
		require.Equal(t, want[:8], extractor.Extract(testString))
	})
}

// TestDiscriminatorExtractorBase64Indexes checks that the extractor's lookup table maps
// each character of the standard base64 alphabet to its position in that alphabet.
func TestDiscriminatorExtractorBase64Indexes(t *testing.T) {
	const base64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

	extractor := NewDiscriminatorExtractor()
	for i := 0; i < len(base64Chars); i++ {
		c := rune(base64Chars[i])
		got := extractor.b64Index[c]
		if got == byte(i) {
			continue
		}
		t.Errorf("incorrect index for character %q: expected %d, got %d", c, i, got)
	}
}

// TestExtractor_Extract_ShortInput confirms that Extract panics when given fewer than the
// 12 characters it requires.
func TestExtractor_Extract_ShortInput(t *testing.T) {
	extractor := NewDiscriminatorExtractor()

	defer func() {
		if recover() != nil {
			// The expected panic happened.
			return
		}
		t.Error("expected panic for short input, but none occurred")
	}()

	// 11 characters, one short of the 12 that Extract slices off.
	extractor.Extract("short_input")
}

// Custom extractor is around 40% faster than using stdlib
func BenchmarkDiscriminatorExtraction(b *testing.B) {
generateDiscriminatorDecodeTestData := func(numTestEntries int) []string {
// corresponds to a 12 character base64 encoded string
entrySize := int64(8)
var testData []string
// Create seeded random source
r := mathrand.New(mathrand.NewSource(entrySize))
for range numTestEntries {
data := make([]byte, entrySize)
_, _ = r.Read(data)

testData = append(testData, base64.StdEncoding.EncodeToString(data))
}

return testData
}

b.Run("Standard lib Base64", func(b *testing.B) {
testData := generateDiscriminatorDecodeTestData(b.N)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = base64.StdEncoding.DecodeString(testData[i])
}
})

b.Run("CustomExtractor", func(b *testing.B) {
testData := generateDiscriminatorDecodeTestData(b.N)
extractor := NewDiscriminatorExtractor()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
extractor.Extract(testData[i])
}
})
}
18 changes: 9 additions & 9 deletions pkg/solana/codec/discriminator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestDiscriminator(t *testing.T) {
t.Run("encode and decode return the discriminator", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
encoded, err := c.Encode(&expected, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
Expand All @@ -28,15 +28,15 @@ func TestDiscriminator(t *testing.T) {
})

t.Run("encode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, err := c.Encode(&[]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("encode injects the discriminator if it's not provided", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
encoded, err := c.Encode(nil, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
Expand All @@ -46,37 +46,37 @@ func TestDiscriminator(t *testing.T) {
})

t.Run("decode returns an error if the encoded value is too short", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("decode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("encode returns an error if the value is not a byte slice", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, err := c.Encode(42, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("GetType returns the type of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
require.Equal(t, reflect.TypeOf(&[]byte{}), c.GetType())
})

t.Run("Size returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
size, err := c.Size(0)
require.NoError(t, err)
require.Equal(t, 8, size)
})

t.Run("FixedSize returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
size, err := c.FixedSize()
require.NoError(t, err)
require.Equal(t, 8, size)
Expand Down
Loading
Loading