Skip to content

Commit

Permalink
Merge pull request #200 from nkryuchkov/fix/remove-panics
Browse files Browse the repository at this point in the history
Remove some panics
  • Loading branch information
志宇 authored Mar 6, 2020
2 parents 37c2045 + 80f373a commit 857d105
Show file tree
Hide file tree
Showing 17 changed files with 184 additions and 69 deletions.
7 changes: 3 additions & 4 deletions pkg/app/log_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func newBoltDB(path, appName string) (_ LogStore, err error) {
}

// Write implements io.Writer
func (l *boltDBappLogs) Write(p []byte) (int, error) {
func (l *boltDBappLogs) Write(p []byte) (n int, err error) {
// ensure there is at least timestamp long bytes
if len(p) < 37 {
return 0, io.ErrShortBuffer
Expand All @@ -79,9 +79,8 @@ func (l *boltDBappLogs) Write(p []byte) (int, error) {
}

defer func() {
err := db.Close()
if err != nil {
panic(err)
if closeErr := db.Close(); err == nil {
err = closeErr
}
}()

Expand Down
5 changes: 4 additions & 1 deletion pkg/router/route_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ func (rg *RouteGroup) read(p []byte) (int, error) {
}

func (rg *RouteGroup) write(data []byte, tp *transport.ManagedTransport, rule routing.Rule) (int, error) {
packet := routing.MakeDataPacket(rule.NextRouteID(), data)
packet, err := routing.MakeDataPacket(rule.NextRouteID(), data)
if err != nil {
return 0, err
}

rg.logger.Debugf("Writing packet of type %s, route ID %d and next ID %d", packet.Type(),
rule.KeyRouteID(), rule.NextRouteID())
Expand Down
8 changes: 7 additions & 1 deletion pkg/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,9 +553,15 @@ func (r *router) forwardPacket(ctx context.Context, packet routing.Packet, rule
}

var p routing.Packet

switch packet.Type() {
case routing.DataPacket:
p = routing.MakeDataPacket(rule.NextRouteID(), packet.Payload())
var err error

p, err = routing.MakeDataPacket(rule.NextRouteID(), packet.Payload())
if err != nil {
return err
}
case routing.KeepAlivePacket:
p = routing.MakeKeepAlivePacket(rule.NextRouteID())
case routing.ClosePacket:
Expand Down
16 changes: 12 additions & 4 deletions pkg/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,9 @@ func testForwardRule(t *testing.T, r0, r1 *router, tp1 *transport.ManagedTranspo
r0.saveRouteGroupRules(routing.EdgeRules{Desc: fwdRule.RouteDescriptor(), Forward: fwdRule, Reverse: nil})

// Call handleTransportPacket for r0 (this should in turn, use the rule we added).
packet := routing.MakeDataPacket(fwdRtID[0], []byte("This is a test!"))
packet, err := routing.MakeDataPacket(fwdRtID[0], []byte("This is a test!"))
require.NoError(t, err)

require.NoError(t, r0.handleTransportPacket(context.TODO(), packet))

// r1 should receive the packet handled by r0.
Expand All @@ -457,7 +459,9 @@ func testIntermediaryForwardRule(t *testing.T, r0, r1 *router, tp1 *transport.Ma
require.NoError(t, err)

// Call handleTransportPacket for r0 (this should in turn, use the rule we added).
packet := routing.MakeDataPacket(fwdRtID[0], []byte("This is a test!"))
packet, err := routing.MakeDataPacket(fwdRtID[0], []byte("This is a test!"))
require.NoError(t, err)

require.NoError(t, r0.handleTransportPacket(context.TODO(), packet))

// r1 should receive the packet handled by r0.
Expand Down Expand Up @@ -501,7 +505,9 @@ func testConsumeRule(t *testing.T, r0, r1 *router, tp1 *transport.ManagedTranspo
Reverse: cnsmRule,
})

packet := routing.MakeDataPacket(intFwdRtID[0], []byte("test intermediary forward"))
packet, err := routing.MakeDataPacket(intFwdRtID[0], []byte("test intermediary forward"))
require.NoError(t, err)

require.NoError(t, r0.handleTransportPacket(context.TODO(), packet))

recvPacket, err := r1.tm.ReadPacket()
Expand All @@ -511,7 +517,9 @@ func testConsumeRule(t *testing.T, r0, r1 *router, tp1 *transport.ManagedTranspo
assert.Equal(t, dstRtIDs[1], recvPacket.RouteID())

consumeMsg := []byte("test_consume")
packet = routing.MakeDataPacket(dstRtIDs[1], consumeMsg)
packet, err = routing.MakeDataPacket(dstRtIDs[1], consumeMsg)
require.NoError(t, err)

require.NoError(t, r1.handleTransportPacket(context.TODO(), packet))

rg, ok := r1.routeGroup(fwdRtDesc.Invert())
Expand Down
14 changes: 10 additions & 4 deletions pkg/routing/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package routing

import (
"encoding/binary"
"errors"
"fmt"
"math"
)
Expand All @@ -24,6 +25,11 @@ const (
PacketPayloadOffset = PacketHeaderSize
)

var (
// ErrPayloadTooBig is returned when passed payload is too big (more than math.MaxUint16).
ErrPayloadTooBig = errors.New("packet size exceeded")
)

// PacketType represents packet purpose.
type PacketType byte

Expand Down Expand Up @@ -71,10 +77,10 @@ const (
type RouteID uint32

// MakeDataPacket constructs a new DataPacket.
// If payload size is more than uint16, MakeDataPacket will panic.
func MakeDataPacket(id RouteID, payload []byte) Packet {
// If payload size is more than uint16, MakeDataPacket returns an error.
func MakeDataPacket(id RouteID, payload []byte) (Packet, error) {
if len(payload) > math.MaxUint16 {
panic("packet size exceeded")
return Packet{}, ErrPayloadTooBig
}

packet := make([]byte, PacketHeaderSize+len(payload))
Expand All @@ -84,7 +90,7 @@ func MakeDataPacket(id RouteID, payload []byte) Packet {
binary.BigEndian.PutUint16(packet[PacketPayloadSizeOffset:], uint16(len(payload)))
copy(packet[PacketPayloadOffset:], payload)

return packet
return packet, nil
}

// MakeClosePacket constructs a new ClosePacket.
Expand Down
5 changes: 4 additions & 1 deletion pkg/routing/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestMakeDataPacket(t *testing.T) {
packet := MakeDataPacket(2, []byte("foo"))
packet, err := MakeDataPacket(2, []byte("foo"))
require.NoError(t, err)

expected := []byte{0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x3, 0x66, 0x6f, 0x6f}

assert.Equal(t, expected, []byte(packet))
Expand Down
9 changes: 8 additions & 1 deletion pkg/snet/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func NewRaw(conf Config, dmsgC *dmsg.Client, stcpC *stcp.Client) *Network {
}

// Init initiates server connections.
func (n *Network) Init(ctx context.Context) error {
func (n *Network) Init(_ context.Context) error {
if n.dmsgC != nil {
time.Sleep(200 * time.Millisecond)
go n.dmsgC.Serve()
Expand Down Expand Up @@ -175,6 +175,7 @@ func (n *Network) Dial(ctx context.Context, network string, pk cipher.PubKey, po
if err != nil {
return nil, err
}

return makeConn(conn, network), nil
default:
return nil, ErrUnknownNetwork
Expand All @@ -189,12 +190,14 @@ func (n *Network) Listen(network string, port uint16) (*Listener, error) {
if err != nil {
return nil, err
}

return makeListener(lis, network), nil
case STcpType:
lis, err := n.stcpC.Listen(port)
if err != nil {
return nil, err
}

return makeListener(lis, network), nil
default:
return nil, ErrUnknownNetwork
Expand Down Expand Up @@ -229,6 +232,7 @@ func (l Listener) AcceptConn() (*Conn, error) {
if err != nil {
return nil, err
}

return makeConn(conn, l.network), nil
}

Expand Down Expand Up @@ -268,13 +272,16 @@ func disassembleAddr(addr net.Addr) (pk cipher.PubKey, port uint16) {
if len(strs) != 2 {
panic(fmt.Errorf("network.disassembleAddr: %v %s", "invalid addr", addr.String()))
}

if err := pk.Set(strs[0]); err != nil {
panic(fmt.Errorf("network.disassembleAddr: %v %s", err, addr.String()))
}

if strs[1] != "~" {
if _, err := fmt.Sscanf(strs[1], "%d", &port); err != nil {
panic(fmt.Errorf("network.disassembleAddr: %v", err))
}
}

return
}
1 change: 1 addition & 0 deletions pkg/snet/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func TestDisassembleAddr(t *testing.T) {
addr := dmsg.Addr{
PK: pk, Port: port,
}

gotPK, gotPort := disassembleAddr(addr)
require.Equal(t, pk, gotPK)
require.Equal(t, port, gotPort)
Expand Down
2 changes: 2 additions & 0 deletions pkg/snet/snettest/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ func GenKeyPairs(n int) []KeyPair {
if err != nil {
panic(err)
}

pairs[i] = KeyPair{PK: pk, SK: sk}
}

return pairs
}

Expand Down
37 changes: 26 additions & 11 deletions pkg/transport/entry.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package transport

import (
"errors"
"fmt"
"strings"

"github.com/SkycoinProject/dmsg/cipher"
"github.com/google/uuid"
)

var (
// ErrEdgeIndexNotFound is returned when no edge index was found.
ErrEdgeIndexNotFound = errors.New("edge index not found")
)

// Entry is the unsigned representation of a Transport.
type Entry struct {

Expand Down Expand Up @@ -59,6 +65,7 @@ func (e *Entry) EdgeIndex(pk cipher.PubKey) int {
return i
}
}

return -1
}

Expand Down Expand Up @@ -99,12 +106,13 @@ func (e *Entry) ToBinary() []byte {

// Signature returns signature for Entry calculated from binary
// representation.
func (e *Entry) Signature(secKey cipher.SecKey) cipher.Sig {
func (e *Entry) Signature(secKey cipher.SecKey) (cipher.Sig, error) {
sig, err := cipher.SignPayload(e.ToBinary(), secKey)
if err != nil {
panic(err)
return cipher.Sig{}, err
}
return sig

return sig, nil
}

// SignedEntry holds an Entry and it's associated signatures.
Expand All @@ -116,27 +124,34 @@ type SignedEntry struct {
}

// Sign sets Signature for a given PubKey in correct position
func (se *SignedEntry) Sign(pk cipher.PubKey, secKey cipher.SecKey) bool {
func (se *SignedEntry) Sign(pk cipher.PubKey, secKey cipher.SecKey) error {
idx := se.Entry.EdgeIndex(pk)
if idx == -1 {
return false
return ErrEdgeIndexNotFound
}
se.Signatures[idx] = se.Entry.Signature(secKey)

return true
sig, err := se.Entry.Signature(secKey)
if err != nil {
return err
}

se.Signatures[idx] = sig

return nil
}

// Signature gets Signature for a given PubKey from correct position
func (se *SignedEntry) Signature(pk cipher.PubKey) (cipher.Sig, bool) {
func (se *SignedEntry) Signature(pk cipher.PubKey) (cipher.Sig, error) {
idx := se.Entry.EdgeIndex(pk)
if idx == -1 {
return cipher.Sig{}, false
return cipher.Sig{}, ErrEdgeIndexNotFound
}
return se.Signatures[idx], true

return se.Signatures[idx], nil
}

// NewSignedEntry creates a SignedEntry with first signature
func NewSignedEntry(entry *Entry, pk cipher.PubKey, secKey cipher.SecKey) (*SignedEntry, bool) {
func NewSignedEntry(entry *Entry, pk cipher.PubKey, secKey cipher.SecKey) (*SignedEntry, error) {
se := &SignedEntry{Entry: entry}
return se, se.Sign(pk, secKey)
}
Expand Down
29 changes: 16 additions & 13 deletions pkg/transport/entry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,17 @@ func ExampleSignedEntry_Sign() {
fmt.Println("No signatures set")
}

if ok := sEntry.Sign(pkA, skA); !ok {
fmt.Println("error signing with skA")
if err := sEntry.Sign(pkA, skA); err != nil {
fmt.Println("error signing with skA: ", err)
}

if (!sEntry.Signatures[0].Null() && sEntry.Signatures[1].Null()) ||
(!sEntry.Signatures[1].Null() && sEntry.Signatures[0].Null()) {
fmt.Println("One signature set")
}

if ok := sEntry.Sign(pkB, skB); !ok {
fmt.Println("error signing with skB")
if err := sEntry.Sign(pkB, skB); err != nil {
fmt.Println("error signing with skB: ", err)
}

if !sEntry.Signatures[0].Null() && !sEntry.Signatures[1].Null() {
Expand All @@ -79,35 +80,37 @@ func ExampleSignedEntry_Signature() {

entry := transport.NewEntry(pkA, pkB, "mock", true)
sEntry := &transport.SignedEntry{Entry: entry}
if ok := sEntry.Sign(pkA, skA); !ok {

if err := sEntry.Sign(pkA, skA); err != nil {
fmt.Println("Error signing sEntry with (pkA,skA)")
}
if ok := sEntry.Sign(pkB, skB); !ok {

if err := sEntry.Sign(pkB, skB); err != nil {
fmt.Println("Error signing sEntry with (pkB,skB)")
}

idxA := sEntry.Entry.EdgeIndex(pkA)
idxB := sEntry.Entry.EdgeIndex(pkB)

sigA, okA := sEntry.Signature(pkA)
sigB, okB := sEntry.Signature(pkB)
sigA, errA := sEntry.Signature(pkA)
sigB, errB := sEntry.Signature(pkB)

if okA && sigA == sEntry.Signatures[idxA] {
if errA == nil && sigA == sEntry.Signatures[idxA] {
fmt.Println("SignatureA got")
}

if okB && (sigB == sEntry.Signatures[idxB]) {
if errB == nil && (sigB == sEntry.Signatures[idxB]) {
fmt.Println("SignatureB got")
}

// Incorrect case
pkC, _ := cipher.GenerateKeyPair()
if _, ok := sEntry.Signature(pkC); !ok {
fmt.Printf("SignatureC got error: invalid pubkey")
if _, err := sEntry.Signature(pkC); err != nil {
fmt.Printf("SignatureC got error: %v\n", err)
}

//
// Output: SignatureA got
// SignatureB got
// SignatureC got error: invalid pubkey
// SignatureC got error: edge index not found
}
Loading

0 comments on commit 857d105

Please sign in to comment.