Skip to content

Commit

Permalink
Validate exact expected error in signed header verification tests (#1…
Browse files Browse the repository at this point in the history
…165)

<!--
Please read and fill out this form before submitting your PR.

Please make sure you have reviewed our contributors guide before
submitting your
first PR.
-->

## Overview
Closes: cosmos#1049

Stacked on top of cosmos#1162

<!-- 
Please provide an explanation of the PR, including the appropriate
context,
background, goal, and rationale. If there is an issue with this
information,
please provide a tl;dr and link the issue. 
-->

## Checklist

<!-- 
Please complete the checklist to ensure that the PR is ready to be
reviewed.

IMPORTANT:
PRs should be left in Draft until the below checklist is completed.
-->

- [x] New and updated code has appropriate documentation
- [x] New and updated code has new and/or updated testing
- [x] Required CI checks are passing
- [ ] Visual proof for any user facing features like CLI or
documentation updates
- [x] Linked issues closed with keywords

---------

Co-authored-by: Matthew Sevey <[email protected]>
  • Loading branch information
Manav-Aggarwal and MSevey authored Sep 9, 2023
1 parent bd7664c commit 472f245
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 30 deletions.
21 changes: 13 additions & 8 deletions types/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package types

import (
"encoding"
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -85,7 +84,7 @@ func (h *Header) Verify(untrst header.Header) error {
}
// sanity check fields
if err := verifyNewHeaderAndVals(h, untrstH); err != nil {
return &header.VerifyError{Reason: err}
return err
}

// Check the validator hashes are the same in the case headers are adjacent
Expand Down Expand Up @@ -132,16 +131,22 @@ func verifyNewHeaderAndVals(trusted, untrusted *Header) error {
}

if !untrusted.Time().After(trusted.Time()) {
return fmt.Errorf("expected new header time %v to be after old header time %v",
untrusted.Time(),
trusted.Time())
return fmt.Errorf("%w: %w",
ErrNewHeaderTimeBeforeOldHeaderTime,
fmt.Errorf("expected new header time %v to be after %v",
untrusted.Time(),
trusted.Time(),
),
)
}

if !untrusted.Time().Before(time.Now().Add(maxClockDrift)) {
return fmt.Errorf("new header has a time from the future %v (now: %v; max clock drift: %v)",
return fmt.Errorf("%w: new header time %v (now: %v; max clock drift: %v)",
ErrNewHeaderTimeFromFuture,
untrusted.Time(),
time.Now(),
maxClockDrift)
maxClockDrift,
)
}

return nil
Expand All @@ -150,7 +155,7 @@ func verifyNewHeaderAndVals(trusted, untrusted *Header) error {
// ValidateBasic performs basic validation of a header.
func (h *Header) ValidateBasic() error {
if len(h.ProposerAddress) == 0 {
return errors.New("no proposer address")
return ErrNoProposerAddress
}

return nil
Expand Down
24 changes: 20 additions & 4 deletions types/signed_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ func (sH *SignedHeader) IsZero() bool {
return sH == nil
}

var (
ErrAggregatorSetHashMismatch = errors.New("aggregator set hash in signed header and hash of validator set do not match")
ErrSignatureVerificationFailed = errors.New("signature verification failed")
ErrNoProposerAddress = errors.New("no proposer address")
ErrLastHeaderHashMismatch = errors.New("last header hash mismatch")
ErrLastCommitHashMismatch = errors.New("last commit hash mismatch")
ErrNewHeaderTimeBeforeOldHeaderTime = errors.New("new header has time before old header time")
ErrNewHeaderTimeFromFuture = errors.New("new header has time from future")
)

func (sH *SignedHeader) Verify(untrst header.Header) error {
// Explicit type checks are required due to embedded Header which also does the explicit type check
untrstH, ok := untrst.(*SignedHeader)
Expand Down Expand Up @@ -44,13 +54,19 @@ func (sH *SignedHeader) Verify(untrst header.Header) error {
sHHash := sH.Header.Hash()
if !bytes.Equal(untrstH.LastHeaderHash[:], sHHash) {
return &header.VerifyError{
Reason: fmt.Errorf("last header hash %v does not match hash of previous header %v", untrstH.LastHeaderHash[:], sHHash),
Reason: fmt.Errorf("%w: expected %v, but got %v",
ErrLastHeaderHashMismatch,
untrstH.LastHeaderHash[:], sHHash,
),
}
}
sHLastCommitHash := sH.Commit.GetCommitHash(&untrstH.Header, sH.ProposerAddress)
if !bytes.Equal(untrstH.LastCommitHash[:], sHLastCommitHash) {
return &header.VerifyError{
Reason: fmt.Errorf("last commit hash %v does not match hash of previous header %v", untrstH.LastCommitHash[:], sHHash),
Reason: fmt.Errorf("%w: expected %v, but got %v",
ErrLastCommitHashMismatch,
untrstH.LastCommitHash[:], sHHash,
),
}
}
return nil
Expand Down Expand Up @@ -78,7 +94,7 @@ func (h *SignedHeader) ValidateBasic() error {
}

if !bytes.Equal(h.Validators.Hash(), h.AggregatorsHash[:]) {
return errors.New("aggregator set hash in signed header and hash of validator set do not match")
return ErrAggregatorSetHashMismatch
}

// Make sure there is exactly one signature
Expand All @@ -94,7 +110,7 @@ func (h *SignedHeader) ValidateBasic() error {
return errors.New("signature verification failed, unable to marshal header")
}
if !pubKey.VerifySignature(msg, signature) {
return errors.New("signature verification failed")
return ErrSignatureVerificationFailed
}

return nil
Expand Down
62 changes: 44 additions & 18 deletions types/signed_header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"
"time"

"github.com/celestiaorg/go-header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -15,37 +16,46 @@ func TestVerify(t *testing.T) {
time.Sleep(time.Second)
untrustedAdj, err := GetNextRandomHeader(trusted, privKey)
require.NoError(t, err)
fakeAggregatorsHash := header.Hash(GetRandomBytes(32))
fakeLastHeaderHash := header.Hash(GetRandomBytes(32))
fakeLastCommitHash := header.Hash(GetRandomBytes(32))
tests := []struct {
prepare func() (*SignedHeader, bool)
err bool
err error
}{
{
prepare: func() (*SignedHeader, bool) { return untrustedAdj, false },
err: false,
err: nil,
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.AggregatorsHash = GetRandomBytes(32)
return &untrusted, true
untrusted.AggregatorsHash = fakeAggregatorsHash
return &untrusted, false
},
err: &header.VerifyError{
Reason: ErrAggregatorSetHashMismatch,
},
err: true,
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.LastHeaderHash = GetRandomBytes(32)
untrusted.LastHeaderHash = fakeLastHeaderHash
return &untrusted, true
},
err: true,
err: &header.VerifyError{
Reason: ErrLastHeaderHashMismatch,
},
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.LastCommitHash = GetRandomBytes(32)
untrusted.LastCommitHash = fakeLastCommitHash
return &untrusted, true
},
err: true,
err: &header.VerifyError{
Reason: ErrLastCommitHashMismatch,
},
},
{
prepare: func() (*SignedHeader, bool) {
Expand All @@ -54,47 +64,57 @@ func TestVerify(t *testing.T) {
untrusted.Header.BaseHeader.Height++
return &untrusted, true
},
err: false, // Accepts non-adjacent headers
err: nil, // Accepts non-adjacent headers
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.Header.BaseHeader.Time = uint64(untrusted.Header.Time().Truncate(time.Hour).UnixNano())
return &untrusted, true
},
err: true,
err: &header.VerifyError{
Reason: ErrNewHeaderTimeBeforeOldHeaderTime,
},
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.Header.BaseHeader.Time = uint64(untrusted.Header.Time().Add(time.Minute).UnixNano())
return &untrusted, true
},
err: true,
err: &header.VerifyError{
Reason: ErrNewHeaderTimeFromFuture,
},
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.BaseHeader.ChainID = "toaster"
return &untrusted, false // Signature verification should fail
},
err: true,
err: &header.VerifyError{
Reason: ErrSignatureVerificationFailed,
},
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.Version.App = untrusted.Version.App + 1
return &untrusted, false // Signature verification should fail
},
err: true,
err: &header.VerifyError{
Reason: ErrSignatureVerificationFailed,
},
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.ProposerAddress = nil
return &untrusted, true
},
err: true,
err: &header.VerifyError{
Reason: ErrNoProposerAddress,
},
},
}

Expand All @@ -107,11 +127,17 @@ func TestVerify(t *testing.T) {
preparedHeader.Commit = *commit
}
err = trusted.Verify(preparedHeader)
if test.err {
assert.Error(t, err)
} else {
if test.err == nil {
assert.NoError(t, err)
return
}
if err == nil {
t.Errorf("expected err: %v, got nil", test.err)
return
}
reason := err.(*header.VerifyError).Reason
testReason := test.err.(*header.VerifyError).Reason
assert.ErrorIs(t, reason, testReason)
})
}
}

0 comments on commit 472f245

Please sign in to comment.