Skip to content

Commit

Permalink
Improve log poller codec usage and optimise log event discriminator m…
Browse files Browse the repository at this point in the history
…atching (#1014)

* Fix LP codec usage and associated codec cleanup

* Optimise log poller discriminator comparison

* Remove unnecessary comment in MatchingFiltersForEncodedEvent

* lint

* lint discriminator extractor tests

* Fix fuzz test for discriminator extractor

* lint and logging

* Change filtersI DecodeSubKey to include logger and run make generate
  • Loading branch information
ilija42 authored Jan 25, 2025
1 parent 22e6a53 commit a855d4d
Show file tree
Hide file tree
Showing 16 changed files with 450 additions and 207 deletions.
45 changes: 30 additions & 15 deletions pkg/solana/codec/codec_entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ func NewAccountEntry(offchainName string, idlTypes AccountIDLTypes, includeDiscr
return nil, err
}

var discriminator *Discriminator
if includeDiscriminator {
discriminator = NewDiscriminator(idlTypes.Account.Name, true)
}

return newEntry(
offchainName,
idlTypes.Account.Name,
accCodec,
includeDiscriminator,
discriminator,
mod,
), nil
}
Expand All @@ -64,7 +69,7 @@ func NewPDAEntry(offchainName string, pdaTypeDef PDATypeDef, mod codec.Modifier,
offchainName,
offchainName, // PDA seeds do not correlate to anything on-chain so reusing offchain name
accCodec,
false,
nil,
mod,
), nil
}
Expand All @@ -85,7 +90,7 @@ func NewInstructionArgsEntry(offChainName string, idlTypes InstructionArgsIDLTyp
idlTypes.Instruction.Name,
instructionCodecArgs,
// Instruction arguments don't need a discriminator by default
false,
nil,
mod,
), nil
}
Expand All @@ -101,30 +106,40 @@ func NewEventArgsEntry(offChainName string, idlTypes EventIDLTypes, includeDiscr
return nil, err
}

var discriminator *Discriminator
if includeDiscriminator {
discriminator = NewDiscriminator(idlTypes.Event.Name, false)
}

return newEntry(
offChainName,
idlTypes.Event.Name,
eventCodec,
includeDiscriminator,
discriminator,
mod,
), nil
}

func newEntry(
genericName, chainSpecificName string,
typeCodec commonencodings.TypeCodec,
includeDiscriminator bool,
discriminator *Discriminator,
mod codec.Modifier,
) Entry {
return &entry{
genericName: genericName,
chainSpecificName: chainSpecificName,
reflectType: typeCodec.GetType(),
typeCodec: typeCodec,
mod: ensureModifier(mod),
includeDiscriminator: includeDiscriminator,
discriminator: *NewDiscriminator(chainSpecificName),
e := &entry{
genericName: genericName,
chainSpecificName: chainSpecificName,
reflectType: typeCodec.GetType(),
typeCodec: typeCodec,
mod: ensureModifier(mod),
}

if discriminator != nil {
e.discriminator = *discriminator
e.includeDiscriminator = true
}

return e
}

func createRefs(idlTypes IdlTypeDefSlice, builder commonencodings.Builder) *codecRefs {
Expand Down Expand Up @@ -175,8 +190,8 @@ func (e *entry) Decode(encoded []byte) (any, []byte, error) {
}

if !bytes.Equal(e.discriminator.hashPrefix, encoded[:discriminatorLength]) {
return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v for genericName: %q, chainSpecificName: %q",
commontypes.ErrInvalidType, encoded[:discriminatorLength], e.genericName, e.chainSpecificName)
return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v, expected %v, for genericName: %q, chainSpecificName: %q",
commontypes.ErrInvalidType, encoded[:discriminatorLength], e.discriminator.hashPrefix, e.genericName, e.chainSpecificName)
}

encoded = encoded[discriminatorLength:]
Expand Down
20 changes: 14 additions & 6 deletions pkg/solana/codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,26 @@ func (it *codecInterfaceTester) GetAccountString(i int) string {
}

func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte {
if request.TestOn == TestItemType || request.TestOn == testutils.TestEventItem {
return encodeFieldsOnItem(t, request)
if request.TestOn == TestItemType {
return encodeFieldsOnItem(t, request, true)
} else if request.TestOn == testutils.TestEventItem {
return encodeFieldsOnItem(t, request, false)
}

return encodeFieldsOnSliceOrArray(t, request)
}

func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report {
func encodeFieldsOnItem(t *testing.T, request *EncodeRequest, isAccount bool) ocr2types.Report {
buf := new(bytes.Buffer)
// The underlying TestItemAsAccount adds a discriminator by default while being Borsh encoded.
if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
// The underlying TestItem adds a discriminator by default while being Borsh encoded.
if isAccount {
if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
}
} else {
if err := testutils.EncodeRequestToTestItemAsEvent(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
}
}
return buf.Bytes()
}
Expand Down
60 changes: 57 additions & 3 deletions pkg/solana/codec/discriminator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@ import (

const discriminatorLength = 8

func NewDiscriminator(name string) *Discriminator {
sum := sha256.Sum256([]byte("account:" + name))
return &Discriminator{hashPrefix: sum[:discriminatorLength]}
// NewDiscriminator returns a Discriminator whose hash prefix is computed by
// NewDiscriminatorHashPrefix from name and isAccount.
func NewDiscriminator(name string, isAccount bool) *Discriminator {
	prefix := NewDiscriminatorHashPrefix(name, isAccount)
	return &Discriminator{hashPrefix: prefix}
}

// NewDiscriminatorHashPrefix returns the first discriminatorLength bytes of the
// SHA-256 digest of "account:"+name when isAccount is true, and of
// "event:"+name otherwise.
func NewDiscriminatorHashPrefix(name string, isAccount bool) []byte {
	namespace := "event:"
	if isAccount {
		namespace = "account:"
	}

	digest := sha256.Sum256([]byte(namespace + name))
	return digest[:discriminatorLength]
}

type Discriminator struct {
Expand Down Expand Up @@ -69,3 +79,47 @@ func (d Discriminator) Size(_ int) (int, error) {
func (d Discriminator) FixedSize() (int, error) {
return discriminatorLength, nil
}

// DiscriminatorExtractor decodes the leading 8-byte Solana discriminator from a
// base64 (standard alphabet) encoded string using a precomputed lookup table.
// Create instances with NewDiscriminatorExtractor so the table is populated.
type DiscriminatorExtractor struct {
	// b64Index maps an input byte to its 6-bit base64 value. It spans the full
	// byte range so that indexing it with any input byte stays in bounds; bytes
	// outside the base64 alphabet keep the zero value.
	b64Index [256]byte
}

// NewDiscriminatorExtractor is optimised to extract discriminators from base64 encoded strings faster than the base64 lib.
// It fills the lookup table with the position of every character of the
// standard base64 alphabet.
func NewDiscriminatorExtractor() DiscriminatorExtractor {
	instance := DiscriminatorExtractor{}
	const base64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	for i := 0; i < len(base64Chars); i++ {
		instance.b64Index[base64Chars[i]] = byte(i)
	}
	return instance
}

// Extract decodes the first 8 bytes of a base64 encoded string, which
// corresponds to a Solana discriminator. The accompanying benchmark reports it
// as around 40% faster than the standard library decoder.
//
// Extract reads exactly the first 12 characters of data (12 characters decode
// to 9 bytes, of which the first 8 are returned). If data is shorter than 12
// characters, Extract panics.
//
// Extract doesn't handle base64 padding because discriminators shouldn't have
// padding. Any byte outside the base64 alphabet (e.g. '!', '@', space, '=', or
// a non-ASCII byte) is decoded as index 0 (the value of 'A'), so the result
// won't be accurate for such input.
func (e *DiscriminatorExtractor) Extract(data string) []byte {
	var decodeBuffer [9]byte
	d := decodeBuffer[:9]
	s := data[:12] // slice-bounds panic here is the documented short-input behaviour

	// base64 decode: three groups of 4 characters -> 3 bytes each
	for i := 0; i < 3; i++ {
		// decode base64 chars into associated 6-bit values
		c1 := e.b64Index[s[0]]
		c2 := e.b64Index[s[1]]
		c3 := e.b64Index[s[2]]
		c4 := e.b64Index[s[3]]

		// reconstruct raw bytes
		d[0] = (c1 << 2) | (c2 >> 4)
		d[1] = (c2 << 4) | (c3 >> 2)
		d[2] = (c3 << 6) | c4

		// next 3 bytes and next 4 characters
		d = d[3:]
		s = s[4:]
	}

	// the 9th decoded byte is not part of the discriminator
	return decodeBuffer[:discriminatorLength]
}
104 changes: 104 additions & 0 deletions pkg/solana/codec/discriminator_extractor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package codec

import (
"encoding/base64"
mathrand "math/rand"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

// FuzzExtractorHappyPath checks that, for every fuzz input the standard library
// decodes without error, Extract returns the same first 8 bytes as
// base64.StdEncoding.
func FuzzExtractorHappyPath(f *testing.F) {
	// Seed with valid base64 discriminators
	seeds := []string{
		"SGVsbG8gV29ybGQh", // Hello World!
		"AAAAAAAAAAAA",     // Zero bytes
		"////////////",     // Max value bytes
		"QUJDREVGR0hJSktM", // ABCDEFGHIJKL
	}

	for _, seed := range seeds {
		f.Add(seed)
	}

	extractor := NewDiscriminatorExtractor()
	f.Fuzz(func(t *testing.T, testString string) {
		// Extractor doesn't validate padding, newlines, or tabs.
		// A trailing "==" also ends in "=", so one suffix check covers both.
		if len(testString) < 12 ||
			strings.ContainsAny(testString, "\n\r\t") ||
			strings.HasSuffix(testString, "=") {
			return
		}

		stdDecoded, err := base64.StdEncoding.DecodeString(testString)
		if err != nil {
			// Only inputs the standard decoder accepts are compared.
			return
		}

		require.Equal(t, stdDecoded[:8], extractor.Extract(testString))
	})
}

// TestDiscriminatorExtractorBase64Indexes verifies that the extractor's lookup
// table maps each character of the standard base64 alphabet to its position in
// that alphabet.
func TestDiscriminatorExtractorBase64Indexes(t *testing.T) {
	const base64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	extractor := NewDiscriminatorExtractor()

	for pos := 0; pos < len(base64Chars); pos++ {
		ch := rune(base64Chars[pos])
		got := extractor.b64Index[ch]
		if got == byte(pos) {
			continue
		}
		t.Errorf("incorrect index for character %q: expected %d, got %d", ch, pos, got)
	}
}

// TestExtractor_Extract_ShortInput asserts that Extract panics when given fewer
// than the 12 characters it slices from its input.
func TestExtractor_Extract_ShortInput(t *testing.T) {
	// 11 characters, one short of the 12 that Extract needs.
	const tooShort = "short_input"

	extractor := NewDiscriminatorExtractor()
	defer func() {
		if recover() != nil {
			return
		}
		t.Error("expected panic for short input, but none occurred")
	}()

	extractor.Extract(tooShort)
}

// Custom extractor is around 40% faster than using stdlib
func BenchmarkDiscriminatorExtraction(b *testing.B) {
generateDiscriminatorDecodeTestData := func(numTestEntries int) []string {
// corresponds to a 12 character base64 encoded string
entrySize := int64(8)
var testData []string
// Create seeded random source
r := mathrand.New(mathrand.NewSource(entrySize))
for range numTestEntries {
data := make([]byte, entrySize)
_, _ = r.Read(data)

testData = append(testData, base64.StdEncoding.EncodeToString(data))
}

return testData
}

b.Run("Standard lib Base64", func(b *testing.B) {
testData := generateDiscriminatorDecodeTestData(b.N)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = base64.StdEncoding.DecodeString(testData[i])
}
})

b.Run("CustomExtractor", func(b *testing.B) {
testData := generateDiscriminatorDecodeTestData(b.N)
extractor := NewDiscriminatorExtractor()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
extractor.Extract(testData[i])
}
})
}
18 changes: 9 additions & 9 deletions pkg/solana/codec/discriminator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestDiscriminator(t *testing.T) {
t.Run("encode and decode return the discriminator", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
encoded, err := c.Encode(&expected, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
Expand All @@ -28,15 +28,15 @@ func TestDiscriminator(t *testing.T) {
})

t.Run("encode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, err := c.Encode(&[]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("encode injects the discriminator if it's not provided", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
encoded, err := c.Encode(nil, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
Expand All @@ -46,37 +46,37 @@ func TestDiscriminator(t *testing.T) {
})

t.Run("decode returns an error if the encoded value is too short", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("decode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("encode returns an error if the value is not a byte slice", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, err := c.Encode(42, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("GetType returns the type of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
require.Equal(t, reflect.TypeOf(&[]byte{}), c.GetType())
})

t.Run("Size returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
size, err := c.Size(0)
require.NoError(t, err)
require.Equal(t, 8, size)
})

t.Run("FixedSize returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
size, err := c.FixedSize()
require.NoError(t, err)
require.Equal(t, 8, size)
Expand Down
Loading

0 comments on commit a855d4d

Please sign in to comment.