From 73f2e908fce3eed2dbe940e36de81e812dcaf6a0 Mon Sep 17 00:00:00 2001 From: iketheadore Date: Tue, 27 Jun 2017 11:26:17 +0800 Subject: [PATCH] Fix issue #360 --- cmd/skycoin/skycoin.go | 26 +- src/api/webrpc/gatewayer_mock_test.go | 13 +- src/api/webrpc/webrpc.go | 36 +-- src/api/webrpc/webrpc_test.go | 22 +- src/daemon/daemon.go | 84 +++++-- src/daemon/gnet/pool.go | 341 +++++++++++++++----------- src/daemon/gnet/pool_test.go | 111 ++++----- src/daemon/peers.go | 4 +- src/daemon/pex/pex.go | 8 +- src/daemon/pool.go | 5 +- src/daemon/rpc.go | 24 +- src/daemon/storage.go | 49 ++-- src/daemon/storage_test.go | 16 +- src/daemon/visor.go | 45 ++-- src/visor/blockchain.go | 5 +- src/visor/blockchain_parser.go | 16 +- src/visor/visor.go | 36 +-- 17 files changed, 501 insertions(+), 340 deletions(-) diff --git a/cmd/skycoin/skycoin.go b/cmd/skycoin/skycoin.go index f1709878f0..f282f1a2ed 100644 --- a/cmd/skycoin/skycoin.go +++ b/cmd/skycoin/skycoin.go @@ -534,11 +534,16 @@ func Run(c *Config) { return } - go d.Run(quit) + errC := make(chan error, 1) + go func() { + errC <- d.Run() + }() + + var rpc *webrpc.WebRPC // start the webrpc if c.RPCInterface { - rpc, err := webrpc.New( + rpc, err = webrpc.New( fmt.Sprintf("%v:%v", c.RPCInterfaceAddr, c.RPCInterfacePort), webrpc.ChanBuffSize(1000), webrpc.ThreadNum(c.RPCThreadNum), @@ -548,7 +553,9 @@ func Run(c *Config) { return } - go rpc.Run(quit) + go func() { + errC <- rpc.Run() + }() } // Debug only - forces connection on start. Violates thread safety. @@ -623,9 +630,18 @@ func Run(c *Config) { } */ - <-quit + select { + case <-quit: + case err := <-errC: + logger.Error("%v", err) + } + + logger.Info("Shutting down...") + + if rpc != nil { + rpc.Shutdown() + } - logger.Info("Shutting down") gui.Shutdown() d.Shutdown() logger.Info("Goodbye") diff --git a/src/api/webrpc/gatewayer_mock_test.go b/src/api/webrpc/gatewayer_mock_test.go index 087431d1e9..e1c63cb599 100644 --- a/src/api/webrpc/gatewayer_mock_test.go +++ b/src/api/webrpc/gatewayer_mock_test.go @@ -152,7 +152,7 @@ func (m *GatewayerMock) GetTransaction(p0 cipher.SHA256) (*visor.Transaction, er } // GetUnspentOutputs mocked method -func (m *GatewayerMock) GetUnspentOutputs(p0 ...daemon.OutputsFilter) visor.ReadableOutputSet { +func (m *GatewayerMock) GetUnspentOutputs(p0 ...daemon.OutputsFilter) (visor.ReadableOutputSet, error) { ret := m.Called(p0) @@ -165,7 +165,16 @@ func (m *GatewayerMock) GetUnspentOutputs(p0 ...daemon.OutputsFilter) visor.Read panic(fmt.Sprintf("unexpected type: %v", res)) } - return r0 + var r1 error + switch res := ret.Get(1).(type) { + case nil: + case error: + r1 = res + default: + panic(fmt.Sprintf("unexpected type: %v", res)) + } + + return r0, r1 } diff --git a/src/api/webrpc/webrpc.go b/src/api/webrpc/webrpc.go index ed519c8e15..8560ef370c 100644 --- a/src/api/webrpc/webrpc.go +++ b/src/api/webrpc/webrpc.go @@ -117,6 +117,8 @@ type WebRPC struct { mux *http.ServeMux handlers map[string]HandlerFunc gateway Gatewayer + listener net.Listener + quit chan struct{} } // Option is the argument type for creating webrpc instance. @@ -126,6 +128,7 @@ type Option func(*WebRPC) func New(addr string, ops ...Option) (*WebRPC, error) { rpc := &WebRPC{ addr: addr, + quit: make(chan struct{}), } for _, opt := range ops { @@ -176,40 +179,37 @@ func (rpc *WebRPC) initHandlers() error { } // Run starts the webrpc service. -func (rpc *WebRPC) Run(quit chan struct{}) { +func (rpc *WebRPC) Run() error { logger.Infof("start webrpc on http://%s", rpc.addr) + defer logger.Info("webrpc service closed") l, err := net.Listen("tcp", rpc.addr) if err != nil { - logger.Error("%v", err) - close(quit) - return + return err } - c := make(chan struct{}) - q := make(chan struct{}, 1) + rpc.listener = l + + errC := make(chan error, 1) go func() { if err := http.Serve(l, rpc); err != nil { select { - case <-c: + case <-rpc.quit: return default: // the webrpc service failed unexpectly, notify the - logger.Error("%v", err) - q <- struct{}{} + errC <- err } } }() - select { - case <-quit: - close(c) - l.Close() - case <-q: - close(quit) - } - logger.Info("webrpc quit") - return + return <-errC +} + +// Shutdown close the webrpc service +func (rpc *WebRPC) Shutdown() { + close(rpc.quit) + rpc.listener.Close() } // HandleFunc registers handler function diff --git a/src/api/webrpc/webrpc_test.go b/src/api/webrpc/webrpc_test.go index de3b298fa4..0a39613e52 100644 --- a/src/api/webrpc/webrpc_test.go +++ b/src/api/webrpc/webrpc_test.go @@ -7,6 +7,8 @@ import ( "net/http/httptest" "testing" + "time" + "github.com/skycoin/skycoin/src/cipher" "github.com/skycoin/skycoin/src/coin" "github.com/skycoin/skycoin/src/daemon" @@ -17,10 +19,9 @@ import ( func setup() (*WebRPC, func()) { c := make(chan struct{}) - f := func() { - close(c) - } + rpc, err := New( + "0.0.0.0:8081", ChanBuffSize(1), ThreadNum(1), Gateway(&fakeGateway{}), @@ -29,7 +30,9 @@ func setup() (*WebRPC, func()) { panic(err) } - return rpc, f + return rpc, func() { + rpc.Shutdown() + } } type fakeGateway struct { @@ -65,14 +68,14 @@ func (fg fakeGateway) GetBlocksInDepth(vs []uint64) *visor.ReadableBlocks { return nil } -func (fg fakeGateway) GetUnspentOutputs(filters ...daemon.OutputsFilter) visor.ReadableOutputSet { +func (fg fakeGateway) GetUnspentOutputs(filters ...daemon.OutputsFilter) (visor.ReadableOutputSet, error) { v := decodeOutputStr(outputStr) for _, f := range filters { v.HeadOutputs = f(v.HeadOutputs) v.OutgoingOutputs = f(v.OutgoingOutputs) v.IncommingOutputs = f(v.IncommingOutputs) } - return v + return v, nil } func (fg fakeGateway) GetTransaction(txid cipher.SHA256) (*visor.Transaction, error) { @@ -100,7 +103,7 @@ func (fg fakeGateway) GetTimeNow() uint64 { } func TestNewWebRPC(t *testing.T) { - rpc1, err := New(ChanBuffSize(1), ThreadNum(1), Gateway(&fakeGateway{}), Quit(make(chan struct{}))) + rpc1, err := New("0.0.0.0:8080", ChanBuffSize(1), ThreadNum(1), Gateway(&fakeGateway{}), Quit(make(chan struct{}))) assert.Nil(t, err) assert.NotNil(t, rpc1.mux) assert.NotNil(t, rpc1.handlers) @@ -108,7 +111,7 @@ func TestNewWebRPC(t *testing.T) { } func Test_rpcHandler_HandlerFunc(t *testing.T) { - rpc, err := New(ChanBuffSize(1), ThreadNum(1), Gateway(&fakeGateway{}), Quit(make(chan struct{}))) + rpc, err := New("0.0.0.0:8080", ChanBuffSize(1), ThreadNum(1), Gateway(&fakeGateway{}), Quit(make(chan struct{}))) assert.Nil(t, err) rpc.HandleFunc("get_status", getStatusHandler) err = rpc.HandleFunc("get_status", getStatusHandler) @@ -118,6 +121,9 @@ func Test_rpcHandler_HandlerFunc(t *testing.T) { func Test_rpcHandler_Handler(t *testing.T) { rpc, teardown := setup() defer teardown() + go rpc.Run() + + time.Sleep(50 * time.Millisecond) type args struct { httpMethod string diff --git a/src/daemon/daemon.go b/src/daemon/daemon.go index 04939f41e4..0379874a70 100644 --- a/src/daemon/daemon.go +++ b/src/daemon/daemon.go @@ -206,6 +206,8 @@ type Daemon struct { ipCounts *IPCount // Message handling queue messageEvents chan MessageEvent + // quit channel + quitC chan chan struct{} } // NewDaemon returns a Daemon with primitives allocated @@ -241,8 +243,8 @@ func NewDaemon(config Config) (*Daemon, error) { connectionErrors: make(chan ConnectionError, config.Daemon.OutgoingMax), outgoingConnections: NewOutgoingConnections(config.Daemon.OutgoingMax), pendingConnections: NewPendingConnections(config.Daemon.PendingMax), - messageEvents: make(chan MessageEvent, - config.Pool.EventChannelSize), + messageEvents: make(chan MessageEvent, config.Pool.EventChannelSize), + quitC: make(chan chan struct{}), } d.Gateway = NewGateway(config.Gateway, d) @@ -280,6 +282,11 @@ type MessageEvent struct { // over the quit channel provided to Init. The Daemon run loop must be stopped // before calling this function. func (dm *Daemon) Shutdown() { + // close the daemon loop first + q := make(chan struct{}, 1) + dm.quitC <- q + <-q + dm.Pool.Shutdown() dm.Peers.Shutdown() dm.Visor.Shutdown() @@ -287,25 +294,24 @@ func (dm *Daemon) Shutdown() { // Run main loop for peer/connection management. Send anything to quit to shut it // down -func (dm *Daemon) Run(quit chan struct{}) { +func (dm *Daemon) Run() (err error) { defer func() { if r := recover(); r != nil { - logger.Error("recover:%v\n stack:%v", r, string(debug.Stack())) - } - - // close quit to notify the caller this daemon running loop is stopped - if quit != nil { - close(quit) + logger.Errorf("recover:%v\n stack:%v", r, string(debug.Stack())) } }() - c := make(chan struct{}) + errC := make(chan error) // start visor - go dm.Visor.Run(c) + go func() { + errC <- dm.Visor.Run() + }() if !dm.Config.DisableIncomingConnections { - go dm.Pool.Run(c) + go func() { + errC <- dm.Pool.Run() + }() } // TODO -- run blockchain stuff in its own goroutine @@ -335,7 +341,11 @@ func (dm *Daemon) Run(quit chan struct{}) { for { select { - case <-c: + case err = <-errC: + return + case qc := <-dm.quitC: + qc <- struct{}{} + logger.Info("Daemon closed") return // Remove connections that failed to complete the handshake case <-cullInvalidTicker: @@ -476,7 +486,12 @@ func (dm *Daemon) connectToPeer(p *pex.Peer) error { return errors.New("Not localhost") } - if dm.Pool.Pool.IsConnExist(p.Addr) { + conned, err := dm.Pool.Pool.IsConnExist(p.Addr) + if err != nil { + return err + } + + if conned { return errors.New("Already connected") } @@ -570,21 +585,40 @@ func (dm *Daemon) cullInvalidConnections() { // This method only handles the erroneous people from the DHT, but not // malicious nodes now := util.Now() - addrs := dm.expectingIntroductions.CullInvalidConns(func(addr string, t time.Time) bool { - if !dm.Pool.Pool.IsConnExist(addr) { - return true + addrs, err := dm.expectingIntroductions.CullInvalidConns(func(addr string, t time.Time) (bool, error) { + conned, err := dm.Pool.Pool.IsConnExist(addr) + if err != nil { + return false, err + } + + if !conned { + return true, nil } if t.Add(dm.Config.IntroductionWait).Before(now) { - return true + return true, nil } - return false + return false, nil }) + if err != nil { + logger.Error("expectingIntroduction cull invalid connections failed: %v", err) + return + } + for _, a := range addrs { - if dm.Pool.Pool.IsConnExist(a) { + exist, err := dm.Pool.Pool.IsConnExist(a) + if err != nil { + logger.Error("%v", err) + return + } + + if exist { logger.Info("Removing %s for not sending a version", a) - dm.Pool.Pool.Disconnect(a, ErrDisconnectIntroductionTimeout) + if err := dm.Pool.Pool.Disconnect(a, ErrDisconnectIntroductionTimeout); err != nil { + logger.Error("%v", err) + return + } dm.Peers.RemovePeer(a) } } @@ -633,7 +667,13 @@ func (dm *Daemon) onConnect(e ConnectEvent) { dm.pendingConnections.Remove(a) - if !dm.Pool.Pool.IsConnExist(a) { + exist, err := dm.Pool.Pool.IsConnExist(a) + if err != nil { + logger.Error("%v", err) + return + } + + if !exist { logger.Warning("While processing an onConnect event, no pool " + "connection was found") return diff --git a/src/daemon/gnet/pool.go b/src/daemon/gnet/pool.go index 2b50fffffd..47206202d0 100644 --- a/src/daemon/gnet/pool.go +++ b/src/daemon/gnet/pool.go @@ -37,7 +37,8 @@ var ( ErrDisconnectWriteQueueFull DisconnectReason = errors.New("Write queue full") // ErrDisconnectUnexpectedError unexpected error ErrDisconnectUnexpectedError DisconnectReason = errors.New("Unexpected error encountered") - + // ErrConnectionPoolClosed error message indicates the connection pool is closed + ErrConnectionPoolClosed = errors.New("Connection pool is closed") // Logger logger = logging.MustGetLogger("gnet") ) @@ -144,6 +145,9 @@ func (conn *Connection) String() string { // Close close the connection and write queue func (conn *Connection) Close() { conn.Conn.Close() + close(conn.WriteQueue) + conn.WriteQueue = nil + conn.Buffer = &bytes.Buffer{} } // DisconnectCallback triggered on client disconnect @@ -168,8 +172,8 @@ type ConnectionPool struct { connID int // Listening connection listener net.Listener - // member variables access channel - memChannel chan func() + // operations channel + ops chan func() // quit channel quit chan struct{} } @@ -184,33 +188,30 @@ func NewConnectionPool(c Config, state interface{}) *ConnectionPool { addresses: make(map[string]*Connection), SendResults: make(chan SendResult, c.BroadcastResultSize), messageState: state, - quit: make(chan struct{}), - memChannel: make(chan func()), } return pool } // Run starts the connection pool -func (pool *ConnectionPool) Run(q chan struct{}) { +func (pool *ConnectionPool) Run() error { + // init the quit and operations channel here, in case run this pool again. + pool.quit = make(chan struct{}) + pool.ops = make(chan func()) + go func() { - for { - select { - case memActionFunc := <-pool.memChannel: - // this goroutine will handle all member variable's writing and reading actions. - memActionFunc() - case <-pool.quit: - return - } + for op := range pool.ops { + op() } + + logger.Info("Connection pool closed") }() // start the connection accept loop addr := fmt.Sprintf("%s:%v", pool.Config.Address, pool.Config.Port) ln, err := net.Listen("tcp", addr) if err != nil { - logger.Error("%v", err) - return + return err } pool.listener = ln @@ -220,16 +221,15 @@ func (pool *ConnectionPool) Run(q chan struct{}) { conn, err := ln.Accept() if err != nil { // When Accept() returns with a non-nill error, we check the quit - // channel to see if we should continue or quit. If quit, then we quit. + // channel to see if we should continue or quit . If quit, then we quit. // Otherwise we continue select { case <-pool.quit: - return + close(pool.ops) + return nil default: // without the default case the select will block. - logger.Error("%v", err) - close(q) - return + continue } } @@ -239,31 +239,43 @@ func (pool *ConnectionPool) Run(q chan struct{}) { // Shutdown gracefully shutdown the connection pool func (pool *ConnectionPool) Shutdown() { + pool.strand(func() error { + pool.addresses = map[string]*Connection{} + pool.pool = map[int]*Connection{} + return nil + }) + close(pool.quit) pool.listener.Close() pool.listener = nil } // strand ensures all read and write action of pool's member variable are in one thread. -func (pool *ConnectionPool) strand(f func()) { +func (pool *ConnectionPool) strand(f func() error) (err error) { + defer func() { + // send on closed operation channel will panic. + if r := recover(); r != nil { + err = ErrConnectionPoolClosed + } + }() + q := make(chan struct{}) - pool.memChannel <- func() { + pool.ops <- func() { defer close(q) - f() + err = f() } <-q + return } // NewConnection creates a new Connection around a net.Conn. Trying to make a connection -// to an address that is already connected will panic. +// to an address that is already connected will failed. func (pool *ConnectionPool) NewConnection(conn net.Conn, solicited bool) (*Connection, error) { a := conn.RemoteAddr().String() var nc *Connection - var err error - pool.strand(func() { + if err := pool.strand(func() error { if pool.addresses[a] != nil { - err = fmt.Errorf("Already connected to %s", a) - return + return fmt.Errorf("Already connected to %s", a) } pool.connID++ nc = NewConnection(pool, pool.connID, conn, @@ -271,9 +283,12 @@ func (pool *ConnectionPool) NewConnection(conn net.Conn, solicited bool) (*Conne pool.pool[nc.ID] = nc pool.addresses[a] = nc - }) + return nil + }); err != nil { + return nil, err + } - return nc, err + return nc, nil } // ListeningAddress returns address, on which the ConnectionPool @@ -288,25 +303,21 @@ func (pool *ConnectionPool) ListeningAddress() (net.Addr, error) { // Creates a Connection and begins its read and write loop func (pool *ConnectionPool) handleConnection(conn net.Conn, solicited bool) { - a := conn.RemoteAddr().String() - if pool.IsConnExist(a) { - logger.Error("Connection %s already exists", a) + addr := conn.RemoteAddr().String() + exist, err := pool.IsConnExist(addr) + if err != nil { + logger.Error("%v", err) return } - var c *Connection - reason := ErrDisconnectUnexpectedError - defer func() { - logger.Debug("End connection handler of %s", conn.RemoteAddr()) - // notify to exist the receive message loop - if c != nil { - pool.Disconnect(c.Addr(), reason) - } - }() + if exist { + logger.Error("Connection %s already exists", addr) + return + } c, err := pool.NewConnection(conn, solicited) if err != nil { - logger.Error("%v", err) + logger.Error("Create connection failed: %v", err) return } @@ -314,64 +325,71 @@ func (pool *ConnectionPool) handleConnection(conn net.Conn, solicited bool) { pool.Config.ConnectCallback(c.Addr(), solicited) } - msgChan := make(chan []byte, 10) - errChan := make(chan error) - go readLoop(c, pool.Config.ReadTimeout, pool.Config.MaxMessageLength, msgChan, errChan) - for { - select { - case m := <-c.WriteQueue: - if m == nil { - continue - } - err := sendMessage(conn, m, pool.Config.WriteTimeout) - sr := newSendResult(c.Addr(), m, err) - pool.SendResults <- sr - if err != nil { - reason = ErrDisconnectWriteFailed - return - } - pool.updateLastSent(c.Addr(), Now()) - case msg := <-msgChan: - dc, err := pool.receiveMessage(c, msg) - if err != nil { - reason = ErrDisconnectMalformedMessage - return - } - if dc != nil { - reason = dc - return - } - case err := <-errChan: - if err != nil { - reason = err + msgC := make(chan []byte, 10) + errC := make(chan error, 1) + + go func() { + errC <- readLoop(c, pool.Config.ReadTimeout, pool.Config.MaxMessageLength, msgC) + }() + + qc := make(chan chan struct{}) + go func() { + for { + select { + case m := <-c.WriteQueue: + if m == nil { + continue + } + err := sendMessage(conn, m, pool.Config.WriteTimeout) + sr := newSendResult(c.Addr(), m, err) + pool.SendResults <- sr + if err != nil { + errC <- err + return + } + + if err := pool.updateLastSent(c.Addr(), Now()); err != nil { + errC <- err + return + } + case msg := <-msgC: + if err := pool.receiveMessage(c, msg); err != nil { + errC <- err + return + } + case q := <-qc: + q <- struct{}{} return } } + }() + + e := <-errC + q := make(chan struct{}, 1) + qc <- q + <-q + + if err := pool.Disconnect(c.Addr(), e); err != nil { + logger.Error("Disconnect failed: %v", err) } } -func readLoop(conn *Connection, timeout time.Duration, maxMsgLen int, msgChan chan []byte, errChan chan error) { +func readLoop(conn *Connection, timeout time.Duration, maxMsgLen int, msgChan chan []byte) error { // read data from connection - defer func() { - logger.Debug("End readLoop of %s", conn.Addr()) - }() reader := bufio.NewReader(conn.Conn) buf := make([]byte, 1024) - var rerr error for { deadline := time.Time{} if timeout != 0 { deadline = time.Now().Add(timeout) } if err := conn.Conn.SetReadDeadline(deadline); err != nil { - rerr = ErrDisconnectSetReadDeadlineFailed - break + return ErrDisconnectSetReadDeadlineFailed } data, err := readData(reader, buf) if err != nil { - rerr = err - break + return err } if data == nil { @@ -379,13 +397,14 @@ func readLoop(conn *Connection, timeout time.Duration, maxMsgLen int, msgChan ch } // write date to buffer. - conn.Buffer.Write(data) + if _, err := conn.Buffer.Write(data); err != nil { + return err + } // decode data datas, err := decodeData(conn.Buffer, maxMsgLen) if err != nil { - rerr = err - break + return err } for _, d := range datas { @@ -394,17 +413,10 @@ func readLoop(conn *Connection, timeout time.Duration, maxMsgLen int, msgChan ch select { case msgChan <- d: default: - return + return errors.New("The msgChan has no receiver") } } } - - if rerr != nil { - select { - case errChan <- rerr: - default: - } - } } func readData(reader io.Reader, buf []byte) ([]byte, error) { @@ -460,52 +472,63 @@ func decodeData(buf *bytes.Buffer, maxMsgLength int) ([][]byte, error) { } // IsConnExist check if the connection of address does exist -func (pool *ConnectionPool) IsConnExist(addr string) bool { +func (pool *ConnectionPool) IsConnExist(addr string) (bool, error) { var exist bool - pool.strand(func() { + if err := pool.strand(func() error { if _, ok := pool.addresses[addr]; ok { exist = true } - }) - return exist + return nil + }); err != nil { + return false, fmt.Errorf("Check connection existence failed: %v ", err) + } + + return exist, nil } -func (pool *ConnectionPool) updateLastSent(addr string, t time.Time) { - pool.strand(func() { +func (pool *ConnectionPool) updateLastSent(addr string, t time.Time) error { + return pool.strand(func() error { if conn, ok := pool.addresses[addr]; ok { conn.LastSent = t } + return nil }) } -func (pool *ConnectionPool) updateLastRecv(addr string, t time.Time) { - pool.strand(func() { +func (pool *ConnectionPool) updateLastRecv(addr string, t time.Time) error { + return pool.strand(func() error { if conn, ok := pool.addresses[addr]; ok { conn.LastReceived = t } + return nil }) } // GetConnection returns a connection copy if exist -func (pool *ConnectionPool) GetConnection(addr string) *Connection { - var conn Connection - var exist bool - pool.strand(func() { +func (pool *ConnectionPool) GetConnection(addr string) (*Connection, error) { + var conn *Connection + if err := pool.strand(func() error { if c, ok := pool.addresses[addr]; ok { // copy connection - conn = *c - exist = true + var cc = *c + conn = &cc } - }) - if exist { - return &conn + return nil + }); err != nil { + return nil, err } - return nil + + return conn, nil } // Connect to an address func (pool *ConnectionPool) Connect(address string) error { - if pool.IsConnExist(address) { + exist, err := pool.IsConnExist(address) + if err != nil { + return err + } + + if exist { return nil } @@ -521,40 +544,47 @@ func (pool *ConnectionPool) Connect(address string) error { // Disconnect removes a connection from the pool by address, and passes a Disconnection to // the DisconnectCallback -func (pool *ConnectionPool) Disconnect(addr string, r DisconnectReason) { +func (pool *ConnectionPool) Disconnect(addr string, r DisconnectReason) error { var exist bool - pool.strand(func() { + if err := pool.strand(func() error { if conn, ok := pool.addresses[addr]; ok { exist = true delete(pool.pool, conn.ID) delete(pool.addresses, addr) conn.Close() } - }) + return nil + }); err != nil { + return err + } if pool.Config.DisconnectCallback != nil && exist { pool.Config.DisconnectCallback(addr, r) } + return nil } // GetConnections returns an copy of pool connections -func (pool *ConnectionPool) GetConnections() []Connection { +func (pool *ConnectionPool) GetConnections() ([]Connection, error) { conns := []Connection{} - pool.strand(func() { + if err := pool.strand(func() error { for _, conn := range pool.pool { conns = append(conns, *conn) } - }) - return conns + return nil + }); err != nil { + return nil, err + } + return conns, nil } // Size returns the pool size -func (pool *ConnectionPool) Size() int { - var l int - pool.strand(func() { +func (pool *ConnectionPool) Size() (l int, err error) { + err = pool.strand(func() error { l = len(pool.pool) + return nil }) - return l + return } // SendMessage sends a Message to a Connection and pushes the result onto the @@ -564,7 +594,7 @@ func (pool *ConnectionPool) SendMessage(addr string, msg Message) error { logger.Debug("Send, Msg Type: %s", reflect.TypeOf(msg)) } var msgQueueFull bool - pool.strand(func() { + if err := pool.strand(func() error { if conn, ok := pool.addresses[addr]; ok { select { case conn.WriteQueue <- msg: @@ -572,7 +602,10 @@ func (pool *ConnectionPool) SendMessage(addr string, msg Message) error { msgQueueFull = true } } - }) + return nil + }); err != nil { + return err + } if msgQueueFull { return ErrDisconnectWriteQueueFull @@ -582,16 +615,15 @@ func (pool *ConnectionPool) SendMessage(addr string, msg Message) error { } // BroadcastMessage sends a Message to all connections in the Pool. -func (pool *ConnectionPool) BroadcastMessage(msg Message) (err error) { +func (pool *ConnectionPool) BroadcastMessage(msg Message) error { if pool.Config.DebugPrint { logger.Debug("Broadcast, Msg Type: %s", reflect.TypeOf(msg)) } fullWriteQueue := []string{} - pool.strand(func() { + if err := pool.strand(func() error { if len(pool.pool) == 0 { - err = errors.New("Connection pool is empty") - return + return errors.New("Connection pool is empty") } for _, conn := range pool.pool { @@ -602,61 +634,80 @@ func (pool *ConnectionPool) BroadcastMessage(msg Message) (err error) { } } if len(fullWriteQueue) == len(pool.pool) { - err = errors.New("There's no available connection in pool") + return errors.New("There's no available connection in pool") } - }) + + return nil + }); err != nil { + return err + } for _, addr := range fullWriteQueue { - pool.Disconnect(addr, ErrDisconnectWriteQueueFull) + if err := pool.Disconnect(addr, ErrDisconnectWriteQueueFull); err != nil { + return err + } } - return + return nil } // Unpacks incoming bytes to a Message and calls the message handler. If // the bytes cannot be converted to a Message, the error is returned as the // first return value. Otherwise, error will be nil and DisconnectReason will // be the value returned from the message handler. -func (pool *ConnectionPool) receiveMessage(c *Connection, msg []byte) (DisconnectReason, error) { +func (pool *ConnectionPool) receiveMessage(c *Connection, msg []byte) error { m, err := convertToMessage(c.ID, msg, pool.Config.DebugPrint) if err != nil { - return nil, err + return err + } + if err := pool.updateLastRecv(c.Addr(), Now()); err != nil { + return err } - pool.updateLastRecv(c.Addr(), Now()) - return m.Handle(NewMessageContext(c), pool.messageState), nil + return m.Handle(NewMessageContext(c), pool.messageState) } // SendPings sends a ping if our last message sent was over pingRate ago -func (pool *ConnectionPool) SendPings(rate time.Duration, msg Message) { +func (pool *ConnectionPool) SendPings(rate time.Duration, msg Message) error { now := util.Now() var addrs []string - pool.strand(func() { + if err := pool.strand(func() error { for _, conn := range pool.pool { if conn.LastSent.Add(rate).Before(now) { addrs = append(addrs, conn.Addr()) } } - }) + return nil + }); err != nil { + return err + } for _, a := range addrs { - pool.SendMessage(a, msg) + if err := pool.SendMessage(a, msg); err != nil { + return err + } } + + return nil } // ClearStaleConnections removes connections that have not sent a message in too long -func (pool *ConnectionPool) ClearStaleConnections(idleLimit time.Duration, reason DisconnectReason) { +func (pool *ConnectionPool) ClearStaleConnections(idleLimit time.Duration, reason DisconnectReason) error { now := Now() idleConns := []string{} - pool.strand(func() { + if err := pool.strand(func() error { for _, conn := range pool.pool { if conn.LastReceived.Add(idleLimit).Before(now) { idleConns = append(idleConns, conn.Addr()) } } - }) + return nil + }); err != nil { + return err + } for _, a := range idleConns { pool.Disconnect(a, reason) } + return nil } // Now returns the current UTC time diff --git a/src/daemon/gnet/pool_test.go b/src/daemon/gnet/pool_test.go index 407e5a9bef..3a62f1c09e 100644 --- a/src/daemon/gnet/pool_test.go +++ b/src/daemon/gnet/pool_test.go @@ -55,7 +55,7 @@ func TestNewConnection(t *testing.T) { cfg.ConnectionWriteQueueSize = 101 p := NewConnectionPool(cfg, nil) defer p.Shutdown() - p.Run() + go p.Run() wait() conn, err := net.Dial("tcp", addr) assert.Nil(t, err) @@ -79,7 +79,7 @@ func TestNewConnectionAlreadyConnected(t *testing.T) { cfg.Address = address p := NewConnectionPool(cfg, nil) defer p.Shutdown() - p.Run() + go p.Run() wait() conn, err := net.Dial("tcp", addr) assert.Nil(t, err) @@ -103,14 +103,9 @@ func TestAcceptConnections(t *testing.T) { } p := NewConnectionPool(cfg, nil) defer p.Shutdown() - // go func() { - p.Run() - assert.NotNil(t, p.listener) - // }() - // go handleXConnections(p, 1) - // go p.AcceptConnections() - // Make a successful connection + go p.Run() wait() + assert.NotNil(t, p.listener) // assert.NotNil(t, p.listener) c, err := net.Dial("tcp", addr) assert.Nil(t, err) @@ -134,21 +129,14 @@ func TestAcceptConnections(t *testing.T) { func TestListeningAddress(t *testing.T) { wait() - t.Run("listening", func(t *testing.T) { - // cleanupNet() - cfg := NewConfig() - cfg.Address = "" - cfg.Port = 0 - p := NewConnectionPool(cfg, nil) - defer p.Shutdown() - // assert.Nil(t, p.StartListen()) - p.Run() - wait() - // addr, err := p.ListeningAddress() - // assert.Nil(t, err) - // assert.NotNil(t, addr) - t.Log("ListeningAddress: ", addr) - }) + cfg := NewConfig() + cfg.Address = "" + cfg.Port = 0 + p := NewConnectionPool(cfg, nil) + defer p.Shutdown() + go p.Run() + wait() + t.Log("ListeningAddress: ", addr) } func TestStartListen(t *testing.T) { @@ -164,15 +152,13 @@ func TestStartListen(t *testing.T) { } p := NewConnectionPool(cfg, nil) defer p.Shutdown() - p.Run() + go p.Run() wait() _, err := net.Dial("tcp", addr) assert.Nil(t, err) wait() assert.True(t, called) assert.NotNil(t, p.listener) - // assert.Nil(t, p.StartListen()) - // p.StopListen() } func TestStartListenTwice(t *testing.T) { @@ -182,9 +168,9 @@ func TestStartListenTwice(t *testing.T) { cfg.Address = address p := NewConnectionPool(cfg, nil) defer p.Shutdown() - p.Run() + go p.Run() wait() - assert.Panics(t, func() { p.Run() }) + assert.NotNil(t, p.Run()) } func TestStartListenFailed(t *testing.T) { @@ -193,12 +179,11 @@ func TestStartListenFailed(t *testing.T) { cfg.Port = uint16(port) cfg.Address = address p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() defer p.Shutdown() wait() q := NewConnectionPool(cfg, nil) - // // // Can't listen on the same port - assert.Panics(t, func() { q.Run() }) + assert.NotNil(t, q.Run()) } func TestStopListen(t *testing.T) { @@ -207,7 +192,7 @@ func TestStopListen(t *testing.T) { cfg.Port = uint16(port) cfg.Address = address p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() wait() assert.NotNil(t, p.listener) conn, err := net.Dial("tcp", addr) @@ -221,7 +206,7 @@ func TestStopListen(t *testing.T) { assert.Equal(t, len(p.pool), 0) assert.Equal(t, len(p.addresses), 0) // Listening again should have no error - p.Run() + go p.Run() wait() p.Shutdown() wait() @@ -243,14 +228,16 @@ func TestHandleConnection(t *testing.T) { called = true } p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() defer p.Shutdown() wait() conn, err := net.Dial("tcp", addr) assert.Nil(t, err) wait() assert.True(t, called) - assert.True(t, p.IsConnExist(conn.LocalAddr().String())) + exist, err := p.IsConnExist(conn.LocalAddr().String()) + assert.Nil(t, err) + assert.True(t, exist) called = false delete(p.addresses, conn.LocalAddr().String()) delete(p.pool, 1) @@ -264,7 +251,9 @@ func TestHandleConnection(t *testing.T) { go p.handleConnection(conn, true) wait() - assert.True(t, p.IsConnExist(conn.RemoteAddr().String())) + exist, err = p.IsConnExist(conn.RemoteAddr().String()) + assert.Nil(t, err) + assert.True(t, exist) assert.True(t, called) called = false assert.Equal(t, len(p.addresses), 1) @@ -278,7 +267,7 @@ func TestConnect(t *testing.T) { cfg.Address = address // cfg.Port p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() wait() err := p.Connect(addr) wait() @@ -316,7 +305,8 @@ func TestConnectNoTimeout(t *testing.T) { cfg.DialTimeout = 0 cfg.Port++ p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() + wait() defer p.Shutdown() err := p.Connect(addr) wait() @@ -329,7 +319,7 @@ func TestDisconnect(t *testing.T) { cfg.Port = uint16(port) cfg.Address = address p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() defer p.Shutdown() wait() _, err := net.Dial("tcp", addr) @@ -377,9 +367,11 @@ func TestGetConnections(t *testing.T) { p.pool[c.ID] = c p.pool[d.ID] = d p.pool[e.ID] = e - p.Run() + go p.Run() + wait() defer p.Shutdown() - conns := p.GetConnections() + conns, err := p.GetConnections() + assert.Nil(t, err) assert.Equal(t, len(conns), 3) m := make(map[int]*Connection, 3) for i, c := range conns { @@ -397,7 +389,7 @@ func TestConnectionReadLoop(t *testing.T) { cfg.Port = uint16(port) cfg.Address = address p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() defer p.Shutdown() wait() @@ -465,7 +457,8 @@ func TestProcessConnectionBuffers(t *testing.T) { cfg.Port = uint16(port) cfg.Address = address p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() + wait() defer p.Shutdown() conn, err := net.Dial("tcp", addr) @@ -566,7 +559,7 @@ func TestConnectionWriteLoop(t *testing.T) { cfg.Port = uint16(port) cfg.Address = address p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() defer p.Shutdown() wait() _, err := net.Dial("tcp", addr) @@ -616,7 +609,7 @@ func TestPoolSendMessage(t *testing.T) { cfg.WriteTimeout = time.Second // cfg.ConnectionWriteQueueSize = 1 p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() defer p.Shutdown() wait() assert.NotEqual(t, p.Config.ConnectionWriteQueueSize, 0) @@ -652,7 +645,7 @@ func TestPoolBroadcastMessage(t *testing.T) { cfg.Address = address cfg.Port = uint16(port) p := NewConnectionPool(cfg, nil) - p.Run() + go p.Run() defer p.Shutdown() wait() @@ -687,32 +680,36 @@ func TestPoolReceiveMessage(t *testing.T) { RegisterMessage(ErrorPrefix, ErrorMessage{}) VerifyMessages() - c := &Connection{} - assert.True(t, c.LastReceived.IsZero()) + // c := &Connection{ + // Conn: NewDummyConn(addr), + // Buffer: &bytes.Buffer{}, + // WriteQueue: make(chan Message), + // } p := NewConnectionPool(NewConfig(), nil) + go p.Run() + wait() + defer p.Shutdown() + c := NewConnection(p, 1, NewDummyConn(addr), 10, true) + // assert.True(t, c.LastReceived.IsZero()) // Valid message received b := make([]byte, 0) b = append(b, BytePrefix[:]...) b = append(b, byte(7)) - reason, err := p.receiveMessage(c, b) + err := p.receiveMessage(c, b) assert.Nil(t, err) assert.False(t, c.LastReceived.IsZero()) - assert.Nil(t, reason) // Invalid byte message received b = []byte{1} - reason, err = p.receiveMessage(c, b) + err = p.receiveMessage(c, b) assert.NotNil(t, err) - assert.Nil(t, reason) // Valid message, but handler returns a DisconnectReason b = make([]byte, 0) b = append(b, ErrorPrefix[:]...) - reason, err = p.receiveMessage(c, b) - assert.Nil(t, err) - assert.NotNil(t, reason) - assert.Equal(t, reason.Error(), "Bad") + err = p.receiveMessage(c, b) + assert.Equal(t, err.Error(), "Bad") } // /* Helpers */ diff --git a/src/daemon/peers.go b/src/daemon/peers.go index 573b9dcc7e..e79323355c 100644 --- a/src/daemon/peers.go +++ b/src/daemon/peers.go @@ -101,15 +101,13 @@ func (ps *Peers) Shutdown() error { return nil } - logger.Info("Saving Peer List") - err := ps.Peers.Save(ps.Config.DataDirectory) if err != nil { logger.Warning("Failed to save peer database") logger.Warning("Reason: %v", err) return err } - logger.Info("Shutdown peers") + logger.Info("Peers saved") return nil } diff --git a/src/daemon/pex/pex.go b/src/daemon/pex/pex.go index d6ec523b0a..d568d9b5f2 100644 --- a/src/daemon/pex/pex.go +++ b/src/daemon/pex/pex.go @@ -444,9 +444,6 @@ func LoadPeerlist(dir string) (*Peerlist, error) { if err := util.LoadJSON(fn, &peerlist.peers); err != nil { return nil, err } - // if err != nil { - // logger.Notice("LoadPeerList Failed: %s", err) - // } return &peerlist, nil } @@ -455,8 +452,6 @@ func LoadPeerlist(dir string) (*Peerlist, error) { type Pex struct { // All known peers *Peerlist - // Ignored peers - // Blacklist Blacklist // If false, localhost peers will be rejected from the peerlist AllowLocalhost bool maxPeers int @@ -465,8 +460,7 @@ type Pex struct { // NewPex creates pex func NewPex(maxPeers int) *Pex { return &Pex{ - Peerlist: &Peerlist{peers: make(map[string]*Peer, maxPeers)}, - // Blacklist: make(Blacklist, 0), + Peerlist: &Peerlist{peers: make(map[string]*Peer, maxPeers)}, maxPeers: maxPeers, AllowLocalhost: false, } diff --git a/src/daemon/pool.go b/src/daemon/pool.go index 54808d6e58..070b9a3d2a 100644 --- a/src/daemon/pool.go +++ b/src/daemon/pool.go @@ -74,14 +74,13 @@ func NewPool(c PoolConfig, d *Daemon) *Pool { func (pool *Pool) Shutdown() { if pool.Pool != nil { pool.Pool.Shutdown() - logger.Info("Shutdown pool") } } // Run starts listening on the configured Port // no goroutine -func (pool *Pool) Run(quit chan struct{}) { - pool.Pool.Run(quit) +func (pool *Pool) Run() error { + return pool.Pool.Run() } // Send a ping if our last message sent was over pingRate ago diff --git a/src/daemon/rpc.go b/src/daemon/rpc.go index 88b3f812b5..5a711250e6 100644 --- a/src/daemon/rpc.go +++ b/src/daemon/rpc.go @@ -51,7 +51,12 @@ func (rpc RPC) GetConnection(d *Daemon, addr string) *Connection { return nil } - c := d.Pool.Pool.GetConnection(addr) + c, err := d.Pool.Pool.GetConnection(addr) + if err != nil { + logger.Error("%v", err) + return nil + } + if c == nil { return nil } @@ -78,8 +83,21 @@ func (rpc RPC) GetConnections(d *Daemon) *Connections { if d.Pool.Pool == nil { return nil } - conns := make([]*Connection, 0, d.Pool.Pool.Size()) - for _, c := range d.Pool.Pool.GetConnections() { + + l, err := d.Pool.Pool.Size() + if err != nil { + logger.Error("%v", err) + return nil + } + + conns := make([]*Connection, 0, l) + cs, err := d.Pool.Pool.GetConnections() + if err != nil { + logger.Error("%v", err) + return nil + } + + for _, c := range cs { if c.Solicited { conn := rpc.GetConnection(d, c.Addr()) if conn != nil { diff --git a/src/daemon/storage.go b/src/daemon/storage.go index dd42b1f24b..0e3f606ce3 100644 --- a/src/daemon/storage.go +++ b/src/daemon/storage.go @@ -13,7 +13,7 @@ type store struct { lk sync.Mutex } -type storeFunc func(*store) +type storeFunc func(*store) error type matchFunc func(k interface{}, v interface{}) bool func (s *store) setValue(k interface{}, v interface{}) { @@ -29,10 +29,10 @@ func (s *store) getValue(k interface{}) (interface{}, bool) { return v, ok } -func (s *store) do(sf storeFunc) { +func (s *store) do(sf storeFunc) error { s.lk.Lock() defer s.lk.Unlock() - sf(s) + return sf(s) } func (s *store) remove(k interface{}) { @@ -53,7 +53,7 @@ type ExpectIntroductions struct { } // CullMatchFunc function for checking if the connection need to be culled -type CullMatchFunc func(addr string, t time.Time) bool +type CullMatchFunc func(addr string, t time.Time) (bool, error) // NewExpectIntroductions creates a ExpectIntroduction instance func NewExpectIntroductions() *ExpectIntroductions { @@ -75,19 +75,28 @@ func (ei *ExpectIntroductions) Remove(addr string) { } // CullInvalidConns cull connections that match the matchFunc -func (ei *ExpectIntroductions) CullInvalidConns(f CullMatchFunc) []string { +func (ei *ExpectIntroductions) CullInvalidConns(f CullMatchFunc) ([]string, error) { var addrs []string - ei.do(func(s *store) { + if err := ei.do(func(s *store) error { for k, v := range s.value { addr := k.(string) t := v.(time.Time) - if f(addr, t) { + ok, err := f(addr, t) + if err != nil { + return err + } + + if ok { addrs = append(addrs, addr) delete(s.value, k) } } - }) - return addrs + return nil + }); err != nil { + return nil, err + } + + return addrs, nil } // Get returns the time of speicific address @@ -220,15 +229,16 @@ func NewMirrorConnections() *MirrorConnections { // Add adds mirror connection func (mc *MirrorConnections) Add(mirror uint32, ip string, port uint16) { - mc.do(func(s *store) { + mc.do(func(s *store) error { if m, ok := s.value[mirror]; ok { m.(map[string]uint16)[ip] = port - return + return nil } m := make(map[string]uint16) m[ip] = port s.value[mirror] = m + return nil }) } @@ -236,21 +246,22 @@ func (mc *MirrorConnections) Add(mirror uint32, ip string, port uint16) { func (mc *MirrorConnections) Get(mirror uint32, ip string) (uint16, bool) { var port uint16 var exist bool - mc.do(func(s *store) { + mc.do(func(s *store) error { if m, ok := s.value[mirror]; ok { port, exist = m.(map[string]uint16)[ip] - return } + return nil }) return port, exist } // Remove removes port of ip for specific mirror func (mc *MirrorConnections) Remove(mirror uint32, ip string) { - mc.do(func(s *store) { + mc.do(func(s *store) error { if m, ok := s.value[mirror]; ok { delete(m.(map[string]uint16), ip) } + return nil }) } @@ -275,30 +286,32 @@ func NewIPCount() *IPCount { // Increase increases one for specific ip func (ic *IPCount) Increase(ip string) { - ic.do(func(s *store) { + ic.do(func(s *store) error { if v, ok := s.value[ip]; ok { c := v.(int) c++ s.value[ip] = c - return + return nil } s.value[ip] = 1 + return nil }) } // Decrease decreases one for specific ip func (ic *IPCount) Decrease(ip string) { - ic.do(func(s *store) { + ic.do(func(s *store) error { if v, ok := s.value[ip]; ok { c := v.(int) if c <= 1 { delete(s.value, ip) - return + return nil } c-- s.value[ip] = c } + return nil }) } diff --git a/src/daemon/storage_test.go b/src/daemon/storage_test.go index bef6fbbba5..d2a7b23f0c 100644 --- a/src/daemon/storage_test.go +++ b/src/daemon/storage_test.go @@ -124,12 +124,14 @@ func TestCullInvalidConnections(t *testing.T) { vc := make(chan string, 3) wg.Add(2) go func(w *sync.WaitGroup) { - as := ei.CullInvalidConns(func(addr string, tm time.Time) bool { + as, err := ei.CullInvalidConns(func(addr string, tm time.Time) (bool, error) { if addr == "a" || addr == "b" { - return true + return true, nil } - return false + return false, nil }) + assert.Nil(t, err) + for _, s := range as { vc <- s } @@ -138,13 +140,15 @@ func TestCullInvalidConnections(t *testing.T) { go func(w *sync.WaitGroup) { // w.Add(1) - as := ei.CullInvalidConns(func(addr string, tm time.Time) bool { + as, err := ei.CullInvalidConns(func(addr string, tm time.Time) (bool, error) { if addr == "c" { - return true + return true, nil } - return false + return false, nil }) + assert.Nil(t, err) + for _, s := range as { vc <- s } diff --git a/src/daemon/visor.go b/src/daemon/visor.go index 4a4b270e63..63b2a565a2 100644 --- a/src/daemon/visor.go +++ b/src/daemon/visor.go @@ -67,7 +67,6 @@ type Visor struct { // Peer-reported blockchain length. Use to estimate download progress blockchainLengths map[string]uint64 reqC chan reqFunc // all request will go through this channel, to keep writing and reading member variable thread safe. - cxt context.Context Shutdown context.CancelFunc } @@ -96,32 +95,33 @@ func NewVisor(c VisorConfig) (*Visor, error) { reqC: make(chan reqFunc, 100), } - var cancel func() - vs.cxt, cancel = context.WithCancel(context.Background()) vs.Shutdown = func() { // close the visor closeVs() - - // cancel the cxt - cancel() } return vs, nil } // Run starts the visor -func (vs *Visor) Run(quit chan struct{}) { - q := make(chan struct{}) - go vs.v.Run(q) +func (vs *Visor) Run() error { + defer logger.Info("Visor closed") + errC := make(chan error, 1) + go func() { + // vs.Shutdown will notify the vs.v.Run to return. + errC <- vs.v.Run() + }() for { select { - case <-q: - close(quit) - case <-vs.cxt.Done(): - return + case err := <-errC: + return err case req := <-vs.reqC: - req(vs.cxt) + func() { + cxt, cancel := context.WithDeadline(context.Background(), time.Now().Add(3*time.Second)) + defer cancel() + req(cxt) + }() } } } @@ -241,7 +241,14 @@ func (vs *Visor) RequestBlocksFromAddr(pool *Pool, addr string) error { var err error vs.strand(func() { m := NewGetBlocksMessage(vs.v.HeadBkSeq(), vs.Config.BlocksResponseCount) - if !pool.Pool.IsConnExist(addr) { + + exist, er := pool.Pool.IsConnExist(addr) + if er != nil { + err = er + return + } + + if !exist { err = fmt.Errorf("Tried to send GetBlocksMessage to %s, but we're "+ "not connected", addr) return @@ -278,7 +285,13 @@ func (vs *Visor) BroadcastTransaction(t coin.Transaction, pool *Pool) { return } m := NewGiveTxnsMessage(coin.Transactions{t}) - logger.Debug("Broadcasting GiveTxnsMessage to %d conns", pool.Pool.Size()) + l, err := pool.Pool.Size() + if err != nil { + logger.Error("Broadcast GivenTxnsMessage failed: %v", err) + return + } + + logger.Debug("Broadcasting GiveTxnsMessage to %d conns", l) pool.Pool.BroadcastMessage(m) } diff --git a/src/visor/blockchain.go b/src/visor/blockchain.go index 2715226ca7..460588967d 100644 --- a/src/visor/blockchain.go +++ b/src/visor/blockchain.go @@ -590,14 +590,13 @@ func (bc *Blockchain) VerifySigs(pubKey cipher.PubKey, sigs *blockdb.BlockSigs) for i := uint64(0); i <= head.Seq(); i++ { b := bc.GetBlockInDepth(i) if b == nil { - return fmt.Errorf("no block in depth %v", i) + return fmt.Errorf("No block in depth %v", i) } // get sig sig, err := sigs.Get(b.HashHeader()) if err != nil { - logger.Info("block sig:%v", i) - return err + return fmt.Errorf("Verify signature of block in depth: %d failed: %v", i, err) } if err := cipher.VerifySignature(pubKey, sig, b.HashHeader()); err != nil { diff --git a/src/visor/blockchain_parser.go b/src/visor/blockchain_parser.go index 5a83a03665..7c509a6b4c 100644 --- a/src/visor/blockchain_parser.go +++ b/src/visor/blockchain_parser.go @@ -44,27 +44,24 @@ func (bcp *BlockchainParser) BlockListener(b coin.Block) { // Run starts blockchain parser, the q channel will be // closed to notify the invoker that the running process // is going to shutdown. -func (bcp *BlockchainParser) Run(q chan struct{}) { +func (bcp *BlockchainParser) Run() error { logger.Info("Blockchain parser start") + defer logger.Info("Blockchain parser closed") // parse to the blockchain head headSeq := bcp.bc.Head().Seq() if err := bcp.parseTo(headSeq); err != nil { - logger.Error("%v", err) - close(q) - return + return err } for { select { case cc := <-bcp.closing: cc <- struct{}{} - return + return nil case b := <-bcp.blkC: if err := bcp.parseTo(b.Head.BkSeq); err != nil { - logger.Error("%v", err) - close(q) - return + return err } } } @@ -72,10 +69,9 @@ func (bcp *BlockchainParser) Run(q chan struct{}) { // Stop close the block parsing process. func (bcp *BlockchainParser) Stop() { - cc := make(chan struct{}) + cc := make(chan struct{}, 1) bcp.closing <- cc <-cc - logger.Info("blockchain parser stopped") } func (bcp *BlockchainParser) parseTo(bcHeight uint64) error { diff --git a/src/visor/visor.go b/src/visor/visor.go index c8fd430a1d..a0e2e6a5f1 100644 --- a/src/visor/visor.go +++ b/src/visor/visor.go @@ -132,8 +132,8 @@ func openDB(dbFile string) (*bolt.DB, func(), error) { } return db, func() { - logger.Info("close db") db.Close() + logger.Info("DB closed") }, nil } @@ -190,18 +190,19 @@ func NewVisor(c Config) (*Visor, VsClose, error) { } return v, func() { - closeDB() v.bcParser.Stop() + closeDB() }, nil } // Run starts the visor process -func (vs *Visor) Run(q chan struct{}) { +func (vs *Visor) Run() error { + errC := make(chan error, 1) + go func() { logger.Info("Verify signature...") if err := vs.Blockchain.VerifySigs(vs.Config.BlockchainPubkey, vs.blockSigs); err != nil { - logger.Error("Invalid block signatures: %v", err) - close(q) + errC <- fmt.Errorf("Invalid block signatures: %v", err) return } logger.Info("Signature verify success") @@ -214,27 +215,34 @@ func (vs *Visor) Run(q chan struct{}) { vs.Config.GenesisCoinVolume, vs.Config.GenesisTimestamp) if err != nil { - logger.Error("%v", err) - close(q) - return + return err } - logger.Debug("create genesis block") + logger.Debug("Create genesis block") // record the signature of genesis block if vs.Config.IsMaster { sb := vs.SignBlock(b) - vs.blockSigs.Add(&sb) - logger.Info("genesis block signature=%s", sb.Sig.Hex()) + if err := vs.blockSigs.Add(&sb); err != nil { + return err + } + + logger.Info("Genesis block signature=%s", sb.Sig.Hex()) } else { - vs.blockSigs.Add(&coin.SignedBlock{ + if err := vs.blockSigs.Add(&coin.SignedBlock{ Block: b, Sig: vs.Config.GenesisSignature, - }) + }); err != nil { + return err + } } } - vs.bcParser.Run(q) + go func() { + errC <- vs.bcParser.Run() + }() + + return <-errC } // GenesisPreconditions panics if conditions for genesis block are not met