diff --git a/consensus/bor/api.go b/consensus/bor/api.go index 26d1efdaf115..6d72e309e3c5 100644 --- a/consensus/bor/api.go +++ b/consensus/bor/api.go @@ -75,13 +75,34 @@ func rankMapDifficulties(values map[common.Address]uint64) []difficultiesKV { } // GetSnapshotProposerSequence retrieves the in-turn signers of all sprints in a span -func (api *API) GetSnapshotProposerSequence(number *rpc.BlockNumber) (BlockSigners, error) { - snapNumber := *number - 1 +func (api *API) GetSnapshotProposerSequence(blockNrOrHash *rpc.BlockNumberOrHash) (BlockSigners, error) { + var header *types.Header + //nolint:nestif + if blockNrOrHash == nil { + header = api.chain.CurrentHeader() + } else { + if blockNr, ok := blockNrOrHash.Number(); ok { + if blockNr == rpc.LatestBlockNumber { + header = api.chain.CurrentHeader() + } else { + header = api.chain.GetHeaderByNumber(uint64(blockNr)) + } + } else { + if blockHash, ok := blockNrOrHash.Hash(); ok { + header = api.chain.GetHeaderByHash(blockHash) + } + } + } - var difficulties = make(map[common.Address]uint64) + if header == nil { + return BlockSigners{}, errUnknownBlock + } + snapNumber := rpc.BlockNumber(header.Number.Int64() - 1) snap, err := api.GetSnapshot(&snapNumber) + var difficulties = make(map[common.Address]uint64) + if err != nil { return BlockSigners{}, err } @@ -101,7 +122,7 @@ func (api *API) GetSnapshotProposerSequence(number *rpc.BlockNumber) (BlockSigne rankedDifficulties := rankMapDifficulties(difficulties) - author, err := api.GetAuthor(number) + author, err := api.GetAuthor(blockNrOrHash) if err != nil { return BlockSigners{}, err } @@ -117,9 +138,31 @@ func (api *API) GetSnapshotProposerSequence(number *rpc.BlockNumber) (BlockSigne } // GetSnapshotProposer retrieves the in-turn signer at a given block. -func (api *API) GetSnapshotProposer(number *rpc.BlockNumber) (common.Address, error) { - *number -= 1 - snap, err := api.GetSnapshot(number) +func (api *API) GetSnapshotProposer(blockNrOrHash *rpc.BlockNumberOrHash) (common.Address, error) { + var header *types.Header + //nolint:nestif + if blockNrOrHash == nil { + header = api.chain.CurrentHeader() + } else { + if blockNr, ok := blockNrOrHash.Number(); ok { + if blockNr == rpc.LatestBlockNumber { + header = api.chain.CurrentHeader() + } else { + header = api.chain.GetHeaderByNumber(uint64(blockNr)) + } + } else { + if blockHash, ok := blockNrOrHash.Hash(); ok { + header = api.chain.GetHeaderByHash(blockHash) + } + } + } + + if header == nil { + return common.Address{}, errUnknownBlock + } + + snapNumber := rpc.BlockNumber(header.Number.Int64() - 1) + snap, err := api.GetSnapshot(&snapNumber) if err != nil { return common.Address{}, err @@ -129,14 +172,26 @@ func (api *API) GetSnapshotProposer(number *rpc.BlockNumber) (common.Address, er } // GetAuthor retrieves the author a block. -func (api *API) GetAuthor(number *rpc.BlockNumber) (*common.Address, error) { +func (api *API) GetAuthor(blockNrOrHash *rpc.BlockNumberOrHash) (*common.Address, error) { // Retrieve the requested block number (or current if none requested) var header *types.Header - if number == nil || *number == rpc.LatestBlockNumber { + + //nolint:nestif + if blockNrOrHash == nil { header = api.chain.CurrentHeader() } else { - header = api.chain.GetHeaderByNumber(uint64(number.Int64())) + if blockNr, ok := blockNrOrHash.Number(); ok { + header = api.chain.GetHeaderByNumber(uint64(blockNr)) + if blockNr == rpc.LatestBlockNumber { + header = api.chain.CurrentHeader() + } + } else { + if blockHash, ok := blockNrOrHash.Hash(); ok { + header = api.chain.GetHeaderByHash(blockHash) + } + } } + // Ensure we have an actually valid block and return its snapshot if header == nil { return nil, errUnknownBlock