From 46184acddaf535fe20d204d881ae2117cdb92723 Mon Sep 17 00:00:00 2001 From: Linsto Hu <91648523+linstohu@users.noreply.github.com> Date: Tue, 26 Dec 2023 10:02:22 +0800 Subject: [PATCH] support graceful close for websocket client (#3) --- .../coinmfutures/websocketmarket/client.go | 70 +++++--- .../websocketmarket/client_test.go | 156 ++++++++---------- .../websocketmarket/subscriptions.go | 2 +- binance/coinmfutures/websocketmarket/vars.go | 4 + .../europeanoptions/websocketmarket/client.go | 70 +++++--- .../websocketmarket/client_test.go | 55 ++---- .../websocketmarket/subscriptions.go | 2 +- .../europeanoptions/websocketmarket/vars.go | 4 + .../websocketuserdata/client.go | 70 +++++--- .../websocketuserdata/client_test.go | 24 +-- .../websocketuserdata/subscriptions.go | 2 +- .../europeanoptions/websocketuserdata/vars.go | 4 + binance/spot/websocketmarket/client.go | 70 +++++--- binance/spot/websocketmarket/client_test.go | 25 +-- binance/spot/websocketmarket/subscriptions.go | 2 +- binance/spot/websocketmarket/vars.go | 4 + binance/usdmfutures/websocketmarket/client.go | 70 +++++--- .../websocketmarket/client_test.go | 125 +++++++------- .../websocketmarket/subscriptions.go | 2 +- binance/usdmfutures/websocketmarket/vars.go | 4 + htx/spot/accountws/client.go | 91 ++++++---- htx/spot/accountws/client_test.go | 19 +-- htx/spot/accountws/subscriptions.go | 2 +- htx/spot/accountws/vars.go | 4 + htx/spot/marketws/client.go | 82 ++++++--- htx/spot/marketws/client_test.go | 25 +-- htx/spot/marketws/subscriptions.go | 2 +- htx/spot/marketws/vars.go | 4 + htx/usdm/accountws/client.go | 102 ++++++++---- htx/usdm/accountws/client_test.go | 24 +-- htx/usdm/accountws/vars.go | 4 + htx/usdm/marketws/client.go | 80 ++++++--- htx/usdm/marketws/client_test.go | 15 +- htx/usdm/marketws/subscriptions.go | 2 +- htx/usdm/marketws/vars.go | 4 + woox/websocket/client.go | 78 ++++++--- woox/websocket/client_test.go | 73 ++------ woox/websocket/subscriptions.go | 2 +- woox/websocket/vars.go | 4 + 39 files changed, 807 insertions(+), 575 deletions(-) diff --git a/binance/coinmfutures/websocketmarket/client.go b/binance/coinmfutures/websocketmarket/client.go index dc2ca53..14b8dea 100644 --- a/binance/coinmfutures/websocketmarket/client.go +++ b/binance/coinmfutures/websocketmarket/client.go @@ -41,7 +41,9 @@ type CoinMarginedMarketStreamClient struct { // logger logger *slog.Logger - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -56,13 +58,14 @@ type CoinMarginedMarketStreamClient struct { } type CoinMarginedMarketStreamCfg struct { - BaseURL string `validate:"required"` - Debug bool - // Logger + Debug bool + BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` + Logger *slog.Logger } -func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg) (*CoinMarginedMarketStreamClient, error) { +func NewMarketStreamClient(cfg *CoinMarginedMarketStreamCfg) (*CoinMarginedMarketStreamClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -72,8 +75,7 @@ func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg debug: cfg.Debug, logger: cfg.Logger, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -83,12 +85,33 @@ func NewMarketStreamClient(ctx context.Context, cfg *CoinMarginedMarketStreamCfg cli.logger = slog.Default() } - err := cli.start() + return cli, nil +} + +func (u *CoinMarginedMarketStreamClient) Open() error { + if u.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + u.stopCtx, u.cancel = context.WithCancel(context.Background()) + + err := u.start() if err != nil { - return nil, err + return err } - return cli, nil + return nil +} + +func (u *CoinMarginedMarketStreamClient) Close() error { + if u.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + u.cancel() + u.stopCtx = nil + + return nil } func (u *CoinMarginedMarketStreamClient) start() error { @@ -99,7 +122,7 @@ func (u *CoinMarginedMarketStreamClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := u.connect() if err != nil { - u.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + u.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -111,6 +134,8 @@ func (u *CoinMarginedMarketStreamClient) start() error { return errors.New("connect failed") } + u.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, u.baseURL)) + u.setIsConnected(true) u.resubscribe() @@ -141,15 +166,14 @@ func (u *CoinMarginedMarketStreamClient) reconnect() { u.setIsConnected(false) - u.logger.Info("disconnect, then reconnect...") - time.Sleep(1 * time.Second) select { - case <-u.ctx.Done(): - u.logger.Info(fmt.Sprintf("never reconnect, %s", u.ctx.Err())) + case <-u.stopCtx.Done(): + u.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + u.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) u.start() } } @@ -185,24 +209,28 @@ func (u *CoinMarginedMarketStreamClient) IsConnected() bool { func (u *CoinMarginedMarketStreamClient) readMessages() { for { select { - case <-u.ctx.Done(): - u.logger.Info(fmt.Sprintf("context done, error: %s", u.ctx.Err().Error())) + case <-u.stopCtx.Done(): + u.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := u.close(); err != nil { - u.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + u.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + u.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: var msg utils.AnyMessage err := u.conn.ReadJSON(&msg) if err != nil { - u.logger.Info(fmt.Sprintf("read object error, %s", err)) + u.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err)) if err := u.close(); err != nil { - u.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + u.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + u.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -212,7 +240,7 @@ func (u *CoinMarginedMarketStreamClient) readMessages() { case msg.SubscribedMessage != nil: err := u.handle(msg.SubscribedMessage) if err != nil { - u.logger.Info(fmt.Sprintf("handle message error: %s", err.Error())) + u.logger.Info(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/binance/coinmfutures/websocketmarket/client_test.go b/binance/coinmfutures/websocketmarket/client_test.go index 56eb422..0a85ba3 100644 --- a/binance/coinmfutures/websocketmarket/client_test.go +++ b/binance/coinmfutures/websocketmarket/client_test.go @@ -15,22 +15,23 @@ * limitations under the License. */ -package websocketmarket +package websocketmarket_test import ( - "context" "fmt" "testing" + coinmws "github.com/linstohu/nexapi/binance/coinmfutures/websocketmarket" spottypes "github.com/linstohu/nexapi/binance/spot/websocketmarket/types" usdmtypes "github.com/linstohu/nexapi/binance/usdmfutures/websocketmarket/types" "github.com/stretchr/testify/assert" ) -func testNewMarketStreamClient(ctx context.Context, t *testing.T) *CoinMarginedMarketStreamClient { - cli, err := NewMarketStreamClient(ctx, &CoinMarginedMarketStreamCfg{ - BaseURL: CoinMarginedMarketStreamBaseURL, - Debug: true, +func testNewMarketStreamClient(t *testing.T) *coinmws.CoinMarginedMarketStreamClient { + cli, err := coinmws.NewMarketStreamClient(&coinmws.CoinMarginedMarketStreamCfg{ + Debug: false, + BaseURL: coinmws.CoinMarginedMarketStreamBaseURL, + AutoReconnect: true, }) if err != nil { @@ -41,10 +42,9 @@ func testNewMarketStreamClient(ctx context.Context, t *testing.T) *CoinMarginedM } func TestSubscribeAggTrade(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAggTradeTopic("btcusd_perp") assert.Nil(t, err) @@ -65,19 +65,18 @@ func TestSubscribeAggTrade(t *testing.T) { } func TestSubscribeIndexPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetIndexPriceTopic(&IndexPriceTopicParam{ + topic, err := cli.GetIndexPriceTopic(&coinmws.IndexPriceTopicParam{ Pair: "btcusd", UpdateSpeed: "1s", }) assert.Nil(t, err) cli.AddListener(topic, func(e any) { - indexPrice, ok := e.(*IndexPrice) + indexPrice, ok := e.(*coinmws.IndexPrice) if !ok { return } @@ -92,19 +91,18 @@ func TestSubscribeIndexPrice(t *testing.T) { } func TestSubscribeMarkPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetMarketPriceTopic(&MarkPriceTopicParam{ + topic, err := cli.GetMarketPriceTopic(&coinmws.MarkPriceTopicParam{ Symbol: "btcusd_perp", UpdateSpeed: "1s", }) assert.Nil(t, err) cli.AddListener(topic, func(e any) { - markprice, ok := e.(*MarkPrice) + markprice, ok := e.(*coinmws.MarkPrice) if !ok { return } @@ -119,19 +117,18 @@ func TestSubscribeMarkPrice(t *testing.T) { } func TestSubscribePairMarkPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetPairMarketPriceTopic(&PairMarkPriceTopicParam{ + topic, err := cli.GetPairMarketPriceTopic(&coinmws.PairMarkPriceTopicParam{ Pair: "btcusd", UpdateSpeed: "1s", }) assert.Nil(t, err) cli.AddListener(topic, func(e any) { - markprices, ok := e.([]*MarkPrice) + markprices, ok := e.([]*coinmws.MarkPrice) if !ok { return } @@ -149,12 +146,11 @@ func TestSubscribePairMarkPrice(t *testing.T) { } func TestSubscribeKline(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetKlineTopic(&KlineTopicParam{ + topic, err := cli.GetKlineTopic(&coinmws.KlineTopicParam{ Symbol: "btcusd_perp", Interval: "1m", }) @@ -176,16 +172,15 @@ func TestSubscribeKline(t *testing.T) { } func TestSubscribeMiniTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetMiniTickerTopic("btcusd_perp") assert.Nil(t, err) cli.AddListener(topic, func(e any) { - ticker, ok := e.(*MiniTicker) + ticker, ok := e.(*coinmws.MiniTicker) if !ok { return } @@ -200,16 +195,15 @@ func TestSubscribeMiniTicker(t *testing.T) { } func TestSubscribeAllMiniTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAllMarketMiniTickersTopic() assert.Nil(t, err) cli.AddListener(topic, func(e any) { - tickers, ok := e.([]*MiniTicker) + tickers, ok := e.([]*coinmws.MiniTicker) if !ok { return } @@ -226,16 +220,15 @@ func TestSubscribeAllMiniTicker(t *testing.T) { } func TestSubscribeTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetTickerTopic("btcusd_perp") assert.Nil(t, err) cli.AddListener(topic, func(e any) { - ticker, ok := e.(*Ticker) + ticker, ok := e.(*coinmws.Ticker) if !ok { return } @@ -250,16 +243,15 @@ func TestSubscribeTicker(t *testing.T) { } func TestSubscribeAllTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAllMarketTickersTopic() assert.Nil(t, err) cli.AddListener(topic, func(e any) { - tickers, ok := e.([]*Ticker) + tickers, ok := e.([]*coinmws.Ticker) if !ok { return } @@ -276,16 +268,15 @@ func TestSubscribeAllTicker(t *testing.T) { } func TestSubscribeBookTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetBookTickerTopic("btcusd_perp") assert.Nil(t, err) cli.AddListener(topic, func(e any) { - book, ok := e.(*BookTicker) + book, ok := e.(*coinmws.BookTicker) if !ok { return } @@ -300,16 +291,15 @@ func TestSubscribeBookTicker(t *testing.T) { } func TestSubscribeAllBookTickers(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAllBookTickersTopic() assert.Nil(t, err) cli.AddListener(topic, func(e any) { - book, ok := e.(*BookTicker) + book, ok := e.(*coinmws.BookTicker) if !ok { return } @@ -324,16 +314,15 @@ func TestSubscribeAllBookTickers(t *testing.T) { } func TestSubscribeLiquidationOrder(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetLiquidationOrderTopic("btcusd_perp") assert.Nil(t, err) cli.AddListener(topic, func(e any) { - order, ok := e.(*LiquidationOrder) + order, ok := e.(*coinmws.LiquidationOrder) if !ok { return } @@ -348,16 +337,15 @@ func TestSubscribeLiquidationOrder(t *testing.T) { } func TestSubscribeAllLiquidationOrders(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAllLiquidationOrdersTopic() assert.Nil(t, err) cli.AddListener(topic, func(e any) { - order, ok := e.(*LiquidationOrder) + order, ok := e.(*coinmws.LiquidationOrder) if !ok { return } @@ -372,12 +360,11 @@ func TestSubscribeAllLiquidationOrders(t *testing.T) { } func TestSubscribeBookDepth(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetBookDepthTopic(&BookDepthTopicParam{ + topic, err := cli.GetBookDepthTopic(&coinmws.BookDepthTopicParam{ Symbol: "btcusd_perp", Level: 5, UpdateSpeed: "500ms", @@ -385,7 +372,7 @@ func TestSubscribeBookDepth(t *testing.T) { assert.Nil(t, err) cli.AddListener(topic, func(e any) { - book, ok := e.(*OrderbookDepth) + book, ok := e.(*coinmws.OrderbookDepth) if !ok { return } @@ -400,12 +387,11 @@ func TestSubscribeBookDepth(t *testing.T) { } func TestSubscribeBookDiffDepth(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetBookDiffDepthTopic(&BookDiffDepthTopicParam{ + topic, err := cli.GetBookDiffDepthTopic(&coinmws.BookDiffDepthTopicParam{ Symbol: "btcusd_perp", UpdateSpeed: "500ms", }) @@ -414,7 +400,7 @@ func TestSubscribeBookDiffDepth(t *testing.T) { fmt.Println(topic) cli.AddListener(topic, func(e any) { - book, ok := e.(*OrderbookDepth) + book, ok := e.(*coinmws.OrderbookDepth) if !ok { return } diff --git a/binance/coinmfutures/websocketmarket/subscriptions.go b/binance/coinmfutures/websocketmarket/subscriptions.go index bfc6f53..5a028ae 100644 --- a/binance/coinmfutures/websocketmarket/subscriptions.go +++ b/binance/coinmfutures/websocketmarket/subscriptions.go @@ -37,7 +37,7 @@ func (u *CoinMarginedMarketStreamClient) UnSubscribe(topics []string) error { func (u *CoinMarginedMarketStreamClient) handle(msg *utils.SubscribedMessage) error { if u.debug { - u.logger.Info(fmt.Sprintf("subscribed message, stream: %s", msg.Stream)) + u.logger.Info(fmt.Sprintf("%s: subscribed message, stream: %s", logPrefix, msg.Stream)) } switch { diff --git a/binance/coinmfutures/websocketmarket/vars.go b/binance/coinmfutures/websocketmarket/vars.go index c102f46..799a814 100644 --- a/binance/coinmfutures/websocketmarket/vars.go +++ b/binance/coinmfutures/websocketmarket/vars.go @@ -22,6 +22,10 @@ var ( CombinedStreamRouter = "/stream" ) +const ( + logPrefix = "binance::coinm::websocketmarket" +) + const ( MaxTryTimes = 5 ) diff --git a/binance/europeanoptions/websocketmarket/client.go b/binance/europeanoptions/websocketmarket/client.go index 21b2497..1d4e5c1 100644 --- a/binance/europeanoptions/websocketmarket/client.go +++ b/binance/europeanoptions/websocketmarket/client.go @@ -41,7 +41,9 @@ type OptionsMarketStreamClient struct { // logger logger *slog.Logger - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -56,13 +58,14 @@ type OptionsMarketStreamClient struct { } type OptionsMarketStreamCfg struct { - BaseURL string `validate:"required"` - Debug bool - // Logger + Debug bool + BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` + Logger *slog.Logger } -func NewMarketStreamClient(ctx context.Context, cfg *OptionsMarketStreamCfg) (*OptionsMarketStreamClient, error) { +func NewMarketStreamClient(cfg *OptionsMarketStreamCfg) (*OptionsMarketStreamClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -72,8 +75,7 @@ func NewMarketStreamClient(ctx context.Context, cfg *OptionsMarketStreamCfg) (*O debug: cfg.Debug, logger: cfg.Logger, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -83,12 +85,33 @@ func NewMarketStreamClient(ctx context.Context, cfg *OptionsMarketStreamCfg) (*O cli.logger = slog.Default() } - err := cli.start() + return cli, nil +} + +func (o *OptionsMarketStreamClient) Open() error { + if o.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + o.stopCtx, o.cancel = context.WithCancel(context.Background()) + + err := o.start() if err != nil { - return nil, err + return err } - return cli, nil + return nil +} + +func (o *OptionsMarketStreamClient) Close() error { + if o.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + o.cancel() + o.stopCtx = nil + + return nil } func (o *OptionsMarketStreamClient) start() error { @@ -99,7 +122,7 @@ func (o *OptionsMarketStreamClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := o.connect() if err != nil { - o.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + o.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -111,6 +134,8 @@ func (o *OptionsMarketStreamClient) start() error { return errors.New("connect failed") } + o.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, o.baseURL)) + o.setIsConnected(true) o.resubscribe() @@ -141,15 +166,14 @@ func (o *OptionsMarketStreamClient) reconnect() { o.setIsConnected(false) - o.logger.Info("disconnect, then reconnect...") - time.Sleep(1 * time.Second) select { - case <-o.ctx.Done(): - o.logger.Info(fmt.Sprintf("never reconnect, %s", o.ctx.Err())) + case <-o.stopCtx.Done(): + o.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + o.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) o.start() } } @@ -185,24 +209,28 @@ func (o *OptionsMarketStreamClient) IsConnected() bool { func (o *OptionsMarketStreamClient) readMessages() { for { select { - case <-o.ctx.Done(): - o.logger.Info(fmt.Sprintf("context done, error: %s", o.ctx.Err().Error())) + case <-o.stopCtx.Done(): + o.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := o.close(); err != nil { - o.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + o.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + o.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: var msg utils.AnyMessage err := o.conn.ReadJSON(&msg) if err != nil { - o.logger.Info(fmt.Sprintf("read object error, %s", err)) + o.logger.Info(fmt.Sprintf("read message error, %s", err)) if err := o.close(); err != nil { - o.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + o.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + o.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -212,7 +240,7 @@ func (o *OptionsMarketStreamClient) readMessages() { case msg.SubscribedMessage != nil: err := o.handle(msg.SubscribedMessage) if err != nil { - o.logger.Info(fmt.Sprintf("handle message error: %s", err.Error())) + o.logger.Info(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/binance/europeanoptions/websocketmarket/client_test.go b/binance/europeanoptions/websocketmarket/client_test.go index 58fde97..338e6af 100644 --- a/binance/europeanoptions/websocketmarket/client_test.go +++ b/binance/europeanoptions/websocketmarket/client_test.go @@ -18,17 +18,17 @@ package websocketmarket import ( - "context" "fmt" "testing" "github.com/stretchr/testify/assert" ) -func testNewMarketStreamClient(ctx context.Context, t *testing.T) *OptionsMarketStreamClient { - cli, err := NewMarketStreamClient(ctx, &OptionsMarketStreamCfg{ - BaseURL: OptionsMarketStreamBaseURL, - Debug: true, +func testNewMarketStreamClient(t *testing.T) *OptionsMarketStreamClient { + cli, err := NewMarketStreamClient(&OptionsMarketStreamCfg{ + Debug: true, + BaseURL: OptionsMarketStreamBaseURL, + AutoReconnect: true, }) if err != nil { @@ -39,10 +39,7 @@ func testNewMarketStreamClient(ctx context.Context, t *testing.T) *OptionsMarket } func TestSubscribeTrade(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.GetTradeTopic("ETH") assert.Nil(t, err) @@ -63,10 +60,7 @@ func TestSubscribeTrade(t *testing.T) { } func TestSubscribeIndexPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.GetIndexPriceTopic("ETHUSDT") assert.Nil(t, err) @@ -87,10 +81,7 @@ func TestSubscribeIndexPrice(t *testing.T) { } func TestSubscribeMarkPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.GetMarkPriceTopic("ETH") assert.Nil(t, err) @@ -113,10 +104,7 @@ func TestSubscribeMarkPrice(t *testing.T) { } func TestSubscribeKline(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.GetKlineTopic(&KlineTopicParam{ Symbol: "ETH-230525-1825-C", @@ -140,10 +128,7 @@ func TestSubscribeKline(t *testing.T) { } func TestSubscribe24HourTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.Get24HourTickerTopic("ETH-230525-1825-C") assert.Nil(t, err) @@ -164,10 +149,7 @@ func TestSubscribe24HourTicker(t *testing.T) { } func TestSubscribeUnderlying24HourTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.Get24HourTickerByUnderlyingAndexpirationTopic("ETH", "230525") assert.Nil(t, err) @@ -190,10 +172,7 @@ func TestSubscribeUnderlying24HourTicker(t *testing.T) { } func TestSubscribeOpenInterest(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.GetOpenInterestTopic("BTC", "230525") assert.Nil(t, err) @@ -216,10 +195,7 @@ func TestSubscribeOpenInterest(t *testing.T) { } func TestSubscribeOrderbook(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.GetBookDepthTopic(&BookDepthTopicParam{ Symbol: "BTC-230602-25000-P", @@ -244,10 +220,7 @@ func TestSubscribeOrderbook(t *testing.T) { } func TestSubscribeDiffOrderbook(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) topic, err := cli.GetBookDiffDepthTopic("BTC-230602-25000-P") assert.Nil(t, err) diff --git a/binance/europeanoptions/websocketmarket/subscriptions.go b/binance/europeanoptions/websocketmarket/subscriptions.go index af331ff..649039b 100644 --- a/binance/europeanoptions/websocketmarket/subscriptions.go +++ b/binance/europeanoptions/websocketmarket/subscriptions.go @@ -35,7 +35,7 @@ func (o *OptionsMarketStreamClient) UnSubscribe(topics []string) error { func (o *OptionsMarketStreamClient) handle(msg *utils.SubscribedMessage) error { if o.debug { - o.logger.Info(fmt.Sprintf("subscribed message, stream: %s", msg.Stream)) + o.logger.Info(fmt.Sprintf("%s: subscribed message, stream: %s", logPrefix, msg.Stream)) } switch { diff --git a/binance/europeanoptions/websocketmarket/vars.go b/binance/europeanoptions/websocketmarket/vars.go index e45e23d..9391e88 100644 --- a/binance/europeanoptions/websocketmarket/vars.go +++ b/binance/europeanoptions/websocketmarket/vars.go @@ -22,6 +22,10 @@ var ( CombinedStreamRouter = "/stream" ) +const ( + logPrefix = "binance::options::websocketmarket" +) + const ( MaxTryTimes = 5 ) diff --git a/binance/europeanoptions/websocketuserdata/client.go b/binance/europeanoptions/websocketuserdata/client.go index f1e1453..38d3497 100644 --- a/binance/europeanoptions/websocketuserdata/client.go +++ b/binance/europeanoptions/websocketuserdata/client.go @@ -41,7 +41,9 @@ type OptionsUserDataStreamClient struct { baseURL string key, secret string - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -58,12 +60,13 @@ type OptionsUserDataStreamCfg struct { // Logger Logger *slog.Logger - BaseURL string `validate:"required"` - Key string `validate:"required"` - Secret string `validate:"required"` + BaseURL string `validate:"required"` + Key string `validate:"required"` + Secret string `validate:"required"` + AutoReconnect bool `validate:"required"` } -func NewUserDataStreamClient(ctx context.Context, cfg *OptionsUserDataStreamCfg) (*OptionsUserDataStreamClient, error) { +func NewUserDataStreamClient(cfg *OptionsUserDataStreamCfg) (*OptionsUserDataStreamClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -76,8 +79,7 @@ func NewUserDataStreamClient(ctx context.Context, cfg *OptionsUserDataStreamCfg) key: cfg.Key, secret: cfg.Secret, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, emitter: emission.NewEmitter(), } @@ -86,12 +88,33 @@ func NewUserDataStreamClient(ctx context.Context, cfg *OptionsUserDataStreamCfg) cli.logger = slog.Default() } - err := cli.start() + return cli, nil +} + +func (o *OptionsUserDataStreamClient) Open() error { + if o.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + o.stopCtx, o.cancel = context.WithCancel(context.Background()) + + err := o.start() if err != nil { - return nil, err + return err } - return cli, nil + return nil +} + +func (o *OptionsUserDataStreamClient) Close() error { + if o.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + o.cancel() + o.stopCtx = nil + + return nil } func (o *OptionsUserDataStreamClient) start() error { @@ -103,7 +126,7 @@ func (o *OptionsUserDataStreamClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := o.connect() if err != nil { - o.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + o.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -115,6 +138,8 @@ func (o *OptionsUserDataStreamClient) start() error { return errors.New("connect failed") } + o.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, o.baseURL)) + o.setIsConnected(true) if o.autoReconnect { @@ -198,17 +223,16 @@ func (o *OptionsUserDataStreamClient) reconnect() { o.setIsConnected(false) - o.logger.Info("disconnect, then reconnect...") - close(o.heartCancel) time.Sleep(1 * time.Second) select { - case <-o.ctx.Done(): - o.logger.Info(fmt.Sprintf("never reconnect, %s", o.ctx.Err())) + case <-o.stopCtx.Done(): + o.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + o.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) o.start() } } @@ -248,7 +272,7 @@ func (o *OptionsUserDataStreamClient) heartbeat() { case <-t.C: err := o.updateListenKey() if err != nil { - o.logger.Info(fmt.Sprintf("websocket update listen-key error, %s", err.Error())) + o.logger.Info(fmt.Sprintf("%s: update listen-key error, %s", logPrefix, err.Error())) } case <-o.heartCancel: return @@ -259,13 +283,15 @@ func (o *OptionsUserDataStreamClient) heartbeat() { func (o *OptionsUserDataStreamClient) readMessages() { for { select { - case <-o.ctx.Done(): - o.logger.Info(fmt.Sprintf("context done, error: %s", o.ctx.Err().Error())) + case <-o.stopCtx.Done(): + o.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := o.close(); err != nil { - o.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + o.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + o.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: _, bytes, err := o.conn.ReadMessage() @@ -273,15 +299,17 @@ func (o *OptionsUserDataStreamClient) readMessages() { o.logger.Info(fmt.Sprintf("read message error, %s", err)) if err := o.close(); err != nil { - o.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + o.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + o.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } err = o.handle(bytes) if err != nil { - o.logger.Info(fmt.Sprintf("handle message error: %s", err.Error())) + o.logger.Info(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/binance/europeanoptions/websocketuserdata/client_test.go b/binance/europeanoptions/websocketuserdata/client_test.go index 485fac0..da9fc8a 100644 --- a/binance/europeanoptions/websocketuserdata/client_test.go +++ b/binance/europeanoptions/websocketuserdata/client_test.go @@ -18,18 +18,18 @@ package websocketuserdata import ( - "context" "fmt" "os" "testing" ) -func testNewUserDataStreamClient(ctx context.Context, t *testing.T) *OptionsUserDataStreamClient { - cli, err := NewUserDataStreamClient(ctx, &OptionsUserDataStreamCfg{ - BaseURL: OptionsUserDataStreamBaseURL, - Key: os.Getenv("BINANCE_KEY"), - Secret: os.Getenv("BINANCE_SECRET"), - Debug: true, +func testNewUserDataStreamClient(t *testing.T) *OptionsUserDataStreamClient { + cli, err := NewUserDataStreamClient(&OptionsUserDataStreamCfg{ + Debug: true, + BaseURL: OptionsUserDataStreamBaseURL, + AutoReconnect: true, + Key: os.Getenv("BINANCE_KEY"), + Secret: os.Getenv("BINANCE_SECRET"), }) if err != nil { @@ -40,10 +40,7 @@ func testNewUserDataStreamClient(ctx context.Context, t *testing.T) *OptionsUser } func TestSubscribeAccountData(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewUserDataStreamClient(ctx, t) + cli := testNewUserDataStreamClient(t) topic := cli.GenAccountDataTopic() @@ -68,10 +65,7 @@ func TestSubscribeAccountData(t *testing.T) { } func TestSubscribeOrderUpdate(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewUserDataStreamClient(ctx, t) + cli := testNewUserDataStreamClient(t) topic := cli.GenOrderUpdateTopic() diff --git a/binance/europeanoptions/websocketuserdata/subscriptions.go b/binance/europeanoptions/websocketuserdata/subscriptions.go index d3b4c15..8dd540e 100644 --- a/binance/europeanoptions/websocketuserdata/subscriptions.go +++ b/binance/europeanoptions/websocketuserdata/subscriptions.go @@ -34,7 +34,7 @@ func (o *OptionsUserDataStreamClient) handle(origind []byte) error { eventType := string(pb.GetStringBytes("e")) if o.debug { - o.logger.Info(fmt.Sprintf("subscribed message, event-type: %s", eventType)) + o.logger.Info(fmt.Sprintf("%s: subscribed message, event-type: %s", logPrefix, eventType)) } switch eventType { diff --git a/binance/europeanoptions/websocketuserdata/vars.go b/binance/europeanoptions/websocketuserdata/vars.go index 0ced330..6f908ae 100644 --- a/binance/europeanoptions/websocketuserdata/vars.go +++ b/binance/europeanoptions/websocketuserdata/vars.go @@ -22,6 +22,10 @@ var ( UserDataStreamRouter = "/ws/" ) +const ( + logPrefix = "binance::options::websocketuserdata" +) + const ( MaxTryTimes = 5 ) diff --git a/binance/spot/websocketmarket/client.go b/binance/spot/websocketmarket/client.go index 89f3f45..d34ceec 100644 --- a/binance/spot/websocketmarket/client.go +++ b/binance/spot/websocketmarket/client.go @@ -41,7 +41,9 @@ type SpotMarketStreamClient struct { // logger logger *slog.Logger - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -56,13 +58,14 @@ type SpotMarketStreamClient struct { } type SpotMarketStreamCfg struct { - BaseURL string `validate:"required"` - Debug bool - // Logger + Debug bool + BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` + Logger *slog.Logger } -func NewSpotMarketStreamClient(ctx context.Context, cfg *SpotMarketStreamCfg) (*SpotMarketStreamClient, error) { +func NewSpotMarketStreamClient(cfg *SpotMarketStreamCfg) (*SpotMarketStreamClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -72,8 +75,7 @@ func NewSpotMarketStreamClient(ctx context.Context, cfg *SpotMarketStreamCfg) (* debug: cfg.Debug, logger: cfg.Logger, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -83,12 +85,33 @@ func NewSpotMarketStreamClient(ctx context.Context, cfg *SpotMarketStreamCfg) (* cli.logger = slog.Default() } - err := cli.start() + return cli, nil +} + +func (m *SpotMarketStreamClient) Open() error { + if m.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + m.stopCtx, m.cancel = context.WithCancel(context.Background()) + + err := m.start() if err != nil { - return nil, err + return err } - return cli, nil + return nil +} + +func (m *SpotMarketStreamClient) Close() error { + if m.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + m.cancel() + m.stopCtx = nil + + return nil } func (m *SpotMarketStreamClient) start() error { @@ -99,7 +122,7 @@ func (m *SpotMarketStreamClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := m.connect() if err != nil { - m.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + m.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -111,6 +134,8 @@ func (m *SpotMarketStreamClient) start() error { return errors.New("connect failed") } + m.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, m.baseURL)) + m.setIsConnected(true) m.resubscribe() @@ -141,15 +166,14 @@ func (m *SpotMarketStreamClient) reconnect() { m.setIsConnected(false) - m.logger.Info("disconnect, then reconnect...") - time.Sleep(1 * time.Second) select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("never reconnect, %s", m.ctx.Err())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + m.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) m.start() } } @@ -185,24 +209,28 @@ func (m *SpotMarketStreamClient) IsConnected() bool { func (m *SpotMarketStreamClient) readMessages() { for { select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("context done, error: %s", m.ctx.Err().Error())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: var msg utils.AnyMessage err := m.conn.ReadJSON(&msg) if err != nil { - m.logger.Info(fmt.Sprintf("read object error, %s", err)) + m.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err)) if err := m.close(); err != nil { - m.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -212,7 +240,7 @@ func (m *SpotMarketStreamClient) readMessages() { case msg.SubscribedMessage != nil: err := m.handle(msg.SubscribedMessage) if err != nil { - m.logger.Info(fmt.Sprintf("handle message error: %s", err.Error())) + m.logger.Info(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/binance/spot/websocketmarket/client_test.go b/binance/spot/websocketmarket/client_test.go index b91bcdd..bb71a68 100644 --- a/binance/spot/websocketmarket/client_test.go +++ b/binance/spot/websocketmarket/client_test.go @@ -15,21 +15,23 @@ * limitations under the License. */ -package websocketmarket +package websocketmarket_test import ( - "context" "fmt" "testing" + "time" + spotws "github.com/linstohu/nexapi/binance/spot/websocketmarket" "github.com/linstohu/nexapi/binance/spot/websocketmarket/types" "github.com/stretchr/testify/assert" ) -func testNewSpotMarketStreamClient(ctx context.Context, t *testing.T) *SpotMarketStreamClient { - cli, err := NewSpotMarketStreamClient(ctx, &SpotMarketStreamCfg{ - BaseURL: SpotMarketStreamBaseURL, - Debug: true, +func testNewSpotMarketStreamClient(t *testing.T) *spotws.SpotMarketStreamClient { + cli, err := spotws.NewSpotMarketStreamClient(&spotws.SpotMarketStreamCfg{ + Debug: false, + BaseURL: spotws.SpotMarketStreamBaseURL, + AutoReconnect: true, }) if err != nil { @@ -40,10 +42,9 @@ func testNewSpotMarketStreamClient(ctx context.Context, t *testing.T) *SpotMarke } func TestSubscribeAggTrade(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewSpotMarketStreamClient(ctx, t) + cli := testNewSpotMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAggTradeTopic("btcusdt") assert.Nil(t, err) @@ -60,5 +61,9 @@ func TestSubscribeAggTrade(t *testing.T) { cli.Subscribe([]string{topic}) + time.Sleep(10 * time.Second) + + cli.Close() + select {} } diff --git a/binance/spot/websocketmarket/subscriptions.go b/binance/spot/websocketmarket/subscriptions.go index cb550ae..c24180f 100644 --- a/binance/spot/websocketmarket/subscriptions.go +++ b/binance/spot/websocketmarket/subscriptions.go @@ -36,7 +36,7 @@ func (m *SpotMarketStreamClient) UnSubscribe(topics []string) error { func (m *SpotMarketStreamClient) handle(msg *utils.SubscribedMessage) error { if m.debug { - m.logger.Info(fmt.Sprintf("subscribed message, stream: %s", msg.Stream)) + m.logger.Info(fmt.Sprintf("%s, subscribed message, stream: %s", logPrefix, msg.Stream)) } switch { diff --git a/binance/spot/websocketmarket/vars.go b/binance/spot/websocketmarket/vars.go index 0cffeac..5e384b5 100644 --- a/binance/spot/websocketmarket/vars.go +++ b/binance/spot/websocketmarket/vars.go @@ -22,6 +22,10 @@ var ( CombinedStreamRouter = "/stream" ) +const ( + logPrefix = "binance::spot::websocketmarket" +) + const ( MaxTryTimes = 5 ) diff --git a/binance/usdmfutures/websocketmarket/client.go b/binance/usdmfutures/websocketmarket/client.go index 50bb307..f60a745 100644 --- a/binance/usdmfutures/websocketmarket/client.go +++ b/binance/usdmfutures/websocketmarket/client.go @@ -41,7 +41,9 @@ type USDMarginedMarketStreamClient struct { // logger logger *slog.Logger - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -56,13 +58,14 @@ type USDMarginedMarketStreamClient struct { } type USDMarginedMarketStreamCfg struct { - BaseURL string `validate:"required"` - Debug bool - // Logger + Debug bool + BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` + Logger *slog.Logger } -func NewMarketStreamClient(ctx context.Context, cfg *USDMarginedMarketStreamCfg) (*USDMarginedMarketStreamClient, error) { +func NewMarketStreamClient(cfg *USDMarginedMarketStreamCfg) (*USDMarginedMarketStreamClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -72,8 +75,7 @@ func NewMarketStreamClient(ctx context.Context, cfg *USDMarginedMarketStreamCfg) debug: cfg.Debug, logger: cfg.Logger, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -83,12 +85,33 @@ func NewMarketStreamClient(ctx context.Context, cfg *USDMarginedMarketStreamCfg) cli.logger = slog.Default() } - err := cli.start() + return cli, nil +} + +func (u *USDMarginedMarketStreamClient) Open() error { + if u.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + u.stopCtx, u.cancel = context.WithCancel(context.Background()) + + err := u.start() if err != nil { - return nil, err + return err } - return cli, nil + return nil +} + +func (u *USDMarginedMarketStreamClient) Close() error { + if u.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + u.cancel() + u.stopCtx = nil + + return nil } func (u *USDMarginedMarketStreamClient) start() error { @@ -99,7 +122,7 @@ func (u *USDMarginedMarketStreamClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := u.connect() if err != nil { - u.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + u.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -111,6 +134,8 @@ func (u *USDMarginedMarketStreamClient) start() error { return errors.New("connect failed") } + u.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, u.baseURL)) + u.setIsConnected(true) u.resubscribe() @@ -141,15 +166,14 @@ func (u *USDMarginedMarketStreamClient) reconnect() { u.setIsConnected(false) - u.logger.Info("disconnect, then reconnect...") - time.Sleep(1 * time.Second) select { - case <-u.ctx.Done(): - u.logger.Info(fmt.Sprintf("never reconnect, %s", u.ctx.Err())) + case <-u.stopCtx.Done(): + u.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + u.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) u.start() } } @@ -185,24 +209,28 @@ func (u *USDMarginedMarketStreamClient) IsConnected() bool { func (u *USDMarginedMarketStreamClient) readMessages() { for { select { - case <-u.ctx.Done(): - u.logger.Info(fmt.Sprintf("context done, error: %s", u.ctx.Err().Error())) + case <-u.stopCtx.Done(): + u.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := u.close(); err != nil { - u.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + u.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + u.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: var msg utils.AnyMessage err := u.conn.ReadJSON(&msg) if err != nil { - u.logger.Info(fmt.Sprintf("read object error, %s", err)) + u.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err)) if err := u.close(); err != nil { - u.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + u.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + u.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -212,7 +240,7 @@ func (u *USDMarginedMarketStreamClient) readMessages() { case msg.SubscribedMessage != nil: err := u.handle(msg.SubscribedMessage) if err != nil { - u.logger.Info(fmt.Sprintf("handle message error: %s", err.Error())) + u.logger.Info(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/binance/usdmfutures/websocketmarket/client_test.go b/binance/usdmfutures/websocketmarket/client_test.go index 4ea294e..004e6d3 100644 --- a/binance/usdmfutures/websocketmarket/client_test.go +++ b/binance/usdmfutures/websocketmarket/client_test.go @@ -15,23 +15,24 @@ * limitations under the License. */ -package websocketmarket +package websocketmarket_test import ( - "context" "fmt" "testing" "time" spottypes "github.com/linstohu/nexapi/binance/spot/websocketmarket/types" + usdmws "github.com/linstohu/nexapi/binance/usdmfutures/websocketmarket" "github.com/linstohu/nexapi/binance/usdmfutures/websocketmarket/types" "github.com/stretchr/testify/assert" ) -func testNewMarketStreamClient(ctx context.Context, t *testing.T) *USDMarginedMarketStreamClient { - cli, err := NewMarketStreamClient(ctx, &USDMarginedMarketStreamCfg{ - BaseURL: USDMarginedMarketStreamBaseURL, - Debug: false, +func testNewMarketStreamClient(t *testing.T) *usdmws.USDMarginedMarketStreamClient { + cli, err := usdmws.NewMarketStreamClient(&usdmws.USDMarginedMarketStreamCfg{ + Debug: false, + BaseURL: usdmws.USDMarginedMarketStreamBaseURL, + AutoReconnect: true, }) if err != nil { @@ -42,10 +43,9 @@ func testNewMarketStreamClient(ctx context.Context, t *testing.T) *USDMarginedMa } func TestSubscribeAggTrade(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAggTradeTopic("btcusdt") assert.Nil(t, err) @@ -66,16 +66,19 @@ func TestSubscribeAggTrade(t *testing.T) { cli.UnSubscribe([]string{topic}) + time.Sleep(1 * time.Second) + + cli.Close() + select {} } func TestSubscribeMarkPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetMarkPriceTopic(&MarkPriceTopicParam{ + topic, err := cli.GetMarkPriceTopic(&usdmws.MarkPriceTopicParam{ Symbol: "btcusdt", UpdateSpeed: "1s", }) @@ -97,12 +100,11 @@ func TestSubscribeMarkPrice(t *testing.T) { } func TestSubscribeAllMarkPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetAllMarketPriceTopic(&AllMarkPriceTopicParam{ + topic, err := cli.GetAllMarketPriceTopic(&usdmws.AllMarkPriceTopicParam{ UpdateSpeed: "1s", }) assert.Nil(t, err) @@ -125,12 +127,11 @@ func TestSubscribeAllMarkPrice(t *testing.T) { } func TestSubscribeKline(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetKlineTopic(&KlineTopicParam{ + topic, err := cli.GetKlineTopic(&usdmws.KlineTopicParam{ Symbol: "btcusdt", Interval: "1m", }) @@ -152,10 +153,9 @@ func TestSubscribeKline(t *testing.T) { } func TestSubscribeMiniTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetMiniTickerTopic("btcusdt") assert.Nil(t, err) @@ -176,10 +176,9 @@ func TestSubscribeMiniTicker(t *testing.T) { } func TestSubscribeAllMiniTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAllMarketMiniTickersTopic() assert.Nil(t, err) @@ -202,10 +201,9 @@ func TestSubscribeAllMiniTicker(t *testing.T) { } func TestSubscribeTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetTickerTopic("btcusdt") assert.Nil(t, err) @@ -226,10 +224,9 @@ func TestSubscribeTicker(t *testing.T) { } func TestSubscribeAllTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAllMarketTickersTopic() assert.Nil(t, err) @@ -252,10 +249,9 @@ func TestSubscribeAllTicker(t *testing.T) { } func TestSubscribeBookTicker(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetBookTickerTopic("btcusdt") assert.Nil(t, err) @@ -276,10 +272,9 @@ func TestSubscribeBookTicker(t *testing.T) { } func TestSubscribeAllBookTickers(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAllBookTickersTopic() assert.Nil(t, err) @@ -300,10 +295,9 @@ func TestSubscribeAllBookTickers(t *testing.T) { } func TestSubscribeLiquidationOrder(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetLiquidationOrderTopic("btcusdt") assert.Nil(t, err) @@ -326,10 +320,9 @@ func TestSubscribeLiquidationOrder(t *testing.T) { } func TestSubscribeAllLiquidationOrders(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) topic, err := cli.GetAllLiquidationOrdersTopic() assert.Nil(t, err) @@ -350,12 +343,11 @@ func TestSubscribeAllLiquidationOrders(t *testing.T) { } func TestSubscribeBookDepth(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetBookDepthTopic(&BookDepthTopicParam{ + topic, err := cli.GetBookDepthTopic(&usdmws.BookDepthTopicParam{ Symbol: "btcusdt", Level: 5, UpdateSpeed: "500ms", @@ -380,12 +372,11 @@ func TestSubscribeBookDepth(t *testing.T) { } func TestSubscribeBookDiffDepth(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketStreamClient(ctx, t) + cli := testNewMarketStreamClient(t) + err := cli.Open() + assert.Nil(t, err) - topic, err := cli.GetBookDiffDepthTopic(&BookDiffDepthTopicParam{ + topic, err := cli.GetBookDiffDepthTopic(&usdmws.BookDiffDepthTopicParam{ Symbol: "btcusdt", UpdateSpeed: "500ms", }) diff --git a/binance/usdmfutures/websocketmarket/subscriptions.go b/binance/usdmfutures/websocketmarket/subscriptions.go index e09baad..0a42e97 100644 --- a/binance/usdmfutures/websocketmarket/subscriptions.go +++ b/binance/usdmfutures/websocketmarket/subscriptions.go @@ -37,7 +37,7 @@ func (u *USDMarginedMarketStreamClient) UnSubscribe(topics []string) error { func (u *USDMarginedMarketStreamClient) handle(msg *utils.SubscribedMessage) error { if u.debug { - u.logger.Info(fmt.Sprintf("subscribed message, stream: %s", msg.Stream)) + u.logger.Info(fmt.Sprintf("%s: subscribed message, stream: %s", logPrefix, msg.Stream)) } switch { diff --git a/binance/usdmfutures/websocketmarket/vars.go b/binance/usdmfutures/websocketmarket/vars.go index 6c0f593..50564e0 100644 --- a/binance/usdmfutures/websocketmarket/vars.go +++ b/binance/usdmfutures/websocketmarket/vars.go @@ -22,6 +22,10 @@ var ( CombinedStreamRouter = "/stream" ) +const ( + logPrefix = "binance::usdm::websocketmarket" +) + const ( MaxTryTimes = 5 ) diff --git a/htx/spot/accountws/client.go b/htx/spot/accountws/client.go index e52ea2f..79c1c10 100644 --- a/htx/spot/accountws/client.go +++ b/htx/spot/accountws/client.go @@ -47,7 +47,9 @@ type AccountWsClient struct { key, secret string - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -62,15 +64,17 @@ type AccountWsClient struct { } type AccountWsClientCfg struct { - Debug bool - // Logger - Logger *slog.Logger - BaseURL string `validate:"required"` - Key string `validate:"required"` - Secret string `validate:"required"` + Debug bool + BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` + + Key string `validate:"required"` + Secret string `validate:"required"` + + Logger *slog.Logger } -func NewAccountWsClient(ctx context.Context, cfg *AccountWsClientCfg) (*AccountWsClient, error) { +func NewAccountWsClient(cfg *AccountWsClientCfg) (*AccountWsClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -89,8 +93,7 @@ func NewAccountWsClient(ctx context.Context, cfg *AccountWsClientCfg) (*AccountW key: cfg.Key, secret: cfg.Secret, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -100,14 +103,35 @@ func NewAccountWsClient(ctx context.Context, cfg *AccountWsClientCfg) (*AccountW cli.logger = slog.Default() } - err = cli.start() + time.Sleep(100 * time.Millisecond) + + return cli, nil +} + +func (m *AccountWsClient) Open() error { + if m.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + m.stopCtx, m.cancel = context.WithCancel(context.Background()) + + err := m.start() if err != nil { - return nil, err + return err } - time.Sleep(100 * time.Millisecond) + return nil +} - return cli, nil +func (m *AccountWsClient) Close() error { + if m.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + m.cancel() + m.stopCtx = nil + + return nil } func (m *AccountWsClient) start() error { @@ -118,7 +142,7 @@ func (m *AccountWsClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := m.connect() if err != nil { - m.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + m.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -130,6 +154,8 @@ func (m *AccountWsClient) start() error { return errors.New("connect failed") } + m.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, m.baseURL)) + m.setIsConnected(true) m.resubscribe() @@ -162,15 +188,14 @@ func (m *AccountWsClient) reconnect() { m.setIsConnected(false) - m.logger.Info("disconnect, then reconnect...") - time.Sleep(1 * time.Second) select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("never reconnect, %s", m.ctx.Err())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + m.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) m.start() } } @@ -252,25 +277,30 @@ func (m *AccountWsClient) auth() error { func (m *AccountWsClient) readMessages() { for { select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("context done, error: %s", m.ctx.Err().Error())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: var msg Message err := m.conn.ReadJSON(&msg) if err != nil { - m.logger.Error(fmt.Sprintf("read object error, %s", err)) + m.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -281,27 +311,30 @@ func (m *AccountWsClient) readMessages() { Data: msg.Data, }) if err != nil { - m.logger.Error(fmt.Sprintf("handle ping error: %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: handle ping error: %s", logPrefix, err.Error())) } case msg.Action == SUB: case msg.Action == REQ: if msg.Channel == "auth" { if msg.Code != 200 { - m.logger.Error(fmt.Sprintf("auth websocket error, action: %s, ch: %s, code: %v", msg.Action, msg.Channel, msg.Code)) + m.logger.Info(fmt.Sprintf("%s: auth websocket error, action: %s, ch: %s, code: %v", logPrefix, msg.Action, msg.Channel, msg.Code)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } else { - m.logger.Info(fmt.Sprintf("auth websocket success, action: %s, ch: %s, code: %v", msg.Action, msg.Channel, msg.Code)) + m.logger.Info(fmt.Sprintf("%s: auth websocket success, action: %s, ch: %s, code: %v", logPrefix, msg.Action, msg.Channel, msg.Code)) } } case msg.Action == PUSH: err := m.handle(&msg) if err != nil { - m.logger.Error(fmt.Sprintf("handle message error: %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/htx/spot/accountws/client_test.go b/htx/spot/accountws/client_test.go index 61dc0d1..ddb72c1 100644 --- a/htx/spot/accountws/client_test.go +++ b/htx/spot/accountws/client_test.go @@ -18,7 +18,6 @@ package accountws import ( - "context" "fmt" "os" "testing" @@ -27,12 +26,13 @@ import ( "github.com/stretchr/testify/assert" ) -func testNewAccountWsClient(ctx context.Context, t *testing.T, url string) *AccountWsClient { - cli, err := NewAccountWsClient(ctx, &AccountWsClientCfg{ - BaseURL: url, - Debug: true, - Key: os.Getenv("HTX_KEY"), - Secret: os.Getenv("HTX_SECRET"), +func testNewAccountWsClient(t *testing.T, url string) *AccountWsClient { + cli, err := NewAccountWsClient(&AccountWsClientCfg{ + Debug: true, + BaseURL: url, + AutoReconnect: true, + Key: os.Getenv("HTX_KEY"), + Secret: os.Getenv("HTX_SECRET"), }) if err != nil { @@ -43,10 +43,7 @@ func testNewAccountWsClient(ctx context.Context, t *testing.T, url string) *Acco } func TestSubscribeAccountUpdate(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewAccountWsClient(ctx, t, GlobalWsBaseURL) + cli := testNewAccountWsClient(t, GlobalWsBaseURL) topic, err := cli.GetAccountUpdateTopic(&AccountUpdateTopicParam{ Mode: 2, diff --git a/htx/spot/accountws/subscriptions.go b/htx/spot/accountws/subscriptions.go index 6c44afd..1124208 100644 --- a/htx/spot/accountws/subscriptions.go +++ b/htx/spot/accountws/subscriptions.go @@ -35,7 +35,7 @@ func (m *AccountWsClient) UnSubscribe(topic string) error { func (m *AccountWsClient) handle(msg *Message) error { if m.debug { - m.logger.Info(fmt.Sprintf("subscribed message, channel: %s", msg.Channel)) + m.logger.Info(fmt.Sprintf("%s: subscribed message, channel: %s", logPrefix, msg.Channel)) } switch { diff --git a/htx/spot/accountws/vars.go b/htx/spot/accountws/vars.go index 4e4709e..5b97bb8 100644 --- a/htx/spot/accountws/vars.go +++ b/htx/spot/accountws/vars.go @@ -22,6 +22,10 @@ const ( GlobalWsBaseURLForAWS = "wss://api-aws.huobi.pro/ws/v2" ) +const ( + logPrefix = "htx::spot::accountws" +) + const ( MaxTryTimes = 5 ) diff --git a/htx/spot/marketws/client.go b/htx/spot/marketws/client.go index 7c46bfd..181aee8 100644 --- a/htx/spot/marketws/client.go +++ b/htx/spot/marketws/client.go @@ -43,7 +43,9 @@ type MarketWsClient struct { // logger logger *slog.Logger - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -58,13 +60,14 @@ type MarketWsClient struct { } type MarketWsClientCfg struct { - BaseURL string `validate:"required"` - Debug bool - // Logger + Debug bool + BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` + Logger *slog.Logger } -func NewMarketWsClient(ctx context.Context, cfg *MarketWsClientCfg) (*MarketWsClient, error) { +func NewMarketWsClient(cfg *MarketWsClientCfg) (*MarketWsClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -74,8 +77,7 @@ func NewMarketWsClient(ctx context.Context, cfg *MarketWsClientCfg) (*MarketWsCl debug: cfg.Debug, logger: cfg.Logger, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -85,12 +87,33 @@ func NewMarketWsClient(ctx context.Context, cfg *MarketWsClientCfg) (*MarketWsCl cli.logger = slog.Default() } - err := cli.start() + return cli, nil +} + +func (m *MarketWsClient) Open() error { + if m.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + m.stopCtx, m.cancel = context.WithCancel(context.Background()) + + err := m.start() if err != nil { - return nil, err + return err } - return cli, nil + return nil +} + +func (m *MarketWsClient) Close() error { + if m.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + m.cancel() + m.stopCtx = nil + + return nil } func (m *MarketWsClient) start() error { @@ -101,7 +124,7 @@ func (m *MarketWsClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := m.connect() if err != nil { - m.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + m.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -113,6 +136,8 @@ func (m *MarketWsClient) start() error { return errors.New("connect failed") } + m.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, m.baseURL)) + m.setIsConnected(true) m.resubscribe() @@ -143,15 +168,14 @@ func (m *MarketWsClient) reconnect() { m.setIsConnected(false) - m.logger.Info("disconnect, then reconnect...") - time.Sleep(1 * time.Second) select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("never reconnect, %s", m.ctx.Err())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + m.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) m.start() } } @@ -187,18 +211,20 @@ func (m *MarketWsClient) IsConnected() bool { func (m *MarketWsClient) readMessages() { for { select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("context done, error: %s", m.ctx.Err().Error())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: msgType, buf, err := m.conn.ReadMessage() if err != nil { - m.logger.Error(fmt.Sprintf("read message error, %s", err)) + m.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err)) time.Sleep(TimerIntervalSecond * time.Second) continue } @@ -207,12 +233,15 @@ func (m *MarketWsClient) readMessages() { if msgType == websocket.BinaryMessage { message, err := htxutils.GZipDecompress(buf) if err != nil { - m.logger.Error(fmt.Sprintf("ungzip data error: %s", err)) + m.logger.Info(fmt.Sprintf("%s: ungzip data error: %s", logPrefix, err)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -220,12 +249,15 @@ func (m *MarketWsClient) readMessages() { err = json.Unmarshal([]byte(message), &msg) if err != nil { - m.logger.Error(fmt.Sprintf("read object error, %s", err)) + m.logger.Info(fmt.Sprintf("%s: read object error, %s", logPrefix, err)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -235,14 +267,14 @@ func (m *MarketWsClient) readMessages() { Pong: msg.Ping.Ping, }) if err != nil { - m.logger.Error(fmt.Sprintf("handle ping error: %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: handle ping error: %s", logPrefix, err.Error())) } case msg.Response != nil: // todo case msg.SubscribedMessage != nil: err := m.handle(msg.SubscribedMessage) if err != nil { - m.logger.Error(fmt.Sprintf("handle message error: %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/htx/spot/marketws/client_test.go b/htx/spot/marketws/client_test.go index 1032188..5df9b8c 100644 --- a/htx/spot/marketws/client_test.go +++ b/htx/spot/marketws/client_test.go @@ -18,7 +18,6 @@ package marketws import ( - "context" "fmt" "testing" @@ -27,10 +26,11 @@ import ( "github.com/stretchr/testify/assert" ) -func testNewMarketWsClient(ctx context.Context, t *testing.T, url string) *MarketWsClient { - cli, err := NewMarketWsClient(ctx, &MarketWsClientCfg{ - BaseURL: url, - Debug: true, +func testNewMarketWsClient(t *testing.T, url string) *MarketWsClient { + cli, err := NewMarketWsClient(&MarketWsClientCfg{ + Debug: true, + BaseURL: url, + AutoReconnect: true, }) if err != nil { @@ -41,10 +41,7 @@ func testNewMarketWsClient(ctx context.Context, t *testing.T, url string) *Marke } func TestSubscribeKline(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketWsClient(ctx, t, GlobalWsBaseURL) + cli := testNewMarketWsClient(t, GlobalWsBaseURL) topic, err := cli.GetKlineTopic(&KlineTopicParam{ Symbol: "btcusdt", @@ -68,10 +65,7 @@ func TestSubscribeKline(t *testing.T) { } func TestSubscribeMBPUpdateDepth(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketWsClient(ctx, t, MBPWsBaseURL) + cli := testNewMarketWsClient(t, MBPWsBaseURL) topic, err := cli.GetMBPDepthUpdateTopic(&MBPDepthUpdateTopicParam{ Symbol: "btcusdt", @@ -100,10 +94,7 @@ func TestSubscribeMBPUpdateDepth(t *testing.T) { } func TestSubscribeMBPRefreshDepth(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketWsClient(ctx, t, GlobalWsBaseURL) + cli := testNewMarketWsClient(t, GlobalWsBaseURL) topic, err := cli.GetMBPRefreshDepthTopic(&MBPDepthRefreshTopicParam{ Symbol: "btcusdt", diff --git a/htx/spot/marketws/subscriptions.go b/htx/spot/marketws/subscriptions.go index 42c13a3..1e738ce 100644 --- a/htx/spot/marketws/subscriptions.go +++ b/htx/spot/marketws/subscriptions.go @@ -35,7 +35,7 @@ func (m *MarketWsClient) UnSubscribe(topic string) error { func (m *MarketWsClient) handle(msg *SubscribedMessage) error { if m.debug { - m.logger.Info(fmt.Sprintf("subscribed message, channel: %s", msg.Channel)) + m.logger.Info(fmt.Sprintf("%s: subscribed message, channel: %s", logPrefix, msg.Channel)) } if strings.Contains(msg.Channel, "mbp") { diff --git a/htx/spot/marketws/vars.go b/htx/spot/marketws/vars.go index a951eaf..d36f5b5 100644 --- a/htx/spot/marketws/vars.go +++ b/htx/spot/marketws/vars.go @@ -25,6 +25,10 @@ const ( MBPWsBaseURLForAWS = "wss://api-aws.huobi.pro/feed" ) +const ( + logPrefix = "htx::spot::marketws" +) + const ( MaxTryTimes = 5 diff --git a/htx/usdm/accountws/client.go b/htx/usdm/accountws/client.go index 7a84b49..f15eeab 100644 --- a/htx/usdm/accountws/client.go +++ b/htx/usdm/accountws/client.go @@ -49,7 +49,9 @@ type AccountWsClient struct { key, secret string - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -64,15 +66,17 @@ type AccountWsClient struct { } type AccountWsClientCfg struct { - Debug bool - // Logger - Logger *slog.Logger - BaseURL string `validate:"required"` - Key string `validate:"required"` - Secret string `validate:"required"` + Debug bool + BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` + + Key string `validate:"required"` + Secret string `validate:"required"` + + Logger *slog.Logger } -func NewAccountWsClient(ctx context.Context, cfg *AccountWsClientCfg) (*AccountWsClient, error) { +func NewAccountWsClient(cfg *AccountWsClientCfg) (*AccountWsClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -92,8 +96,7 @@ func NewAccountWsClient(ctx context.Context, cfg *AccountWsClientCfg) (*AccountW key: cfg.Key, secret: cfg.Secret, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -103,14 +106,35 @@ func NewAccountWsClient(ctx context.Context, cfg *AccountWsClientCfg) (*AccountW cli.logger = slog.Default() } - err = cli.start() + time.Sleep(100 * time.Millisecond) + + return cli, nil +} + +func (m *AccountWsClient) Open() error { + if m.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + m.stopCtx, m.cancel = context.WithCancel(context.Background()) + + err := m.start() if err != nil { - return nil, err + return err } - time.Sleep(100 * time.Millisecond) + return nil +} - return cli, nil +func (m *AccountWsClient) Close() error { + if m.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + m.cancel() + m.stopCtx = nil + + return nil } func (m *AccountWsClient) start() error { @@ -121,7 +145,7 @@ func (m *AccountWsClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := m.connect() if err != nil { - m.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + m.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -133,6 +157,8 @@ func (m *AccountWsClient) start() error { return errors.New("connect failed") } + m.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, m.baseURL)) + m.setIsConnected(true) m.resubscribe() @@ -165,15 +191,14 @@ func (m *AccountWsClient) reconnect() { m.setIsConnected(false) - m.logger.Info("disconnect, then reconnect...") - time.Sleep(1 * time.Second) select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("never reconnect, %s", m.ctx.Err())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + m.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) m.start() } } @@ -252,18 +277,20 @@ func (m *AccountWsClient) auth() error { func (m *AccountWsClient) readMessages() { for { select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("context done, error: %s", m.ctx.Err().Error())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: msgType, buf, err := m.conn.ReadMessage() if err != nil { - m.logger.Error(fmt.Sprintf("read message error, %s", err)) + m.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err)) time.Sleep(TimerIntervalSecond * time.Second) continue } @@ -272,12 +299,15 @@ func (m *AccountWsClient) readMessages() { if msgType == websocket.BinaryMessage { message, err := htxutils.GZipDecompress(buf) if err != nil { - m.logger.Error(fmt.Sprintf("ungzip data error: %s", err)) + m.logger.Info(fmt.Sprintf("%s: ungzip data error: %s", logPrefix, err)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -285,12 +315,15 @@ func (m *AccountWsClient) readMessages() { err = json.Unmarshal([]byte(message), &msg) if err != nil { - m.logger.Error(fmt.Sprintf("read object error, %s", err)) + m.logger.Info(fmt.Sprintf("%s: read object error, %s", logPrefix, err)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -301,28 +334,31 @@ func (m *AccountWsClient) readMessages() { Ts: msg.Ts, }) if err != nil { - m.logger.Error(fmt.Sprintf("handle ping error: %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: handle ping error: %s", logPrefix, err.Error())) } case msg.Operation == SUB: if msg.ErrCode != 0 { - m.logger.Error(fmt.Sprintf("sub websocket error, op: %s, topic: %s, err-code: %v, err-msg: %v", msg.Operation, msg.Topic, msg.ErrCode, msg.ErrMsg)) + m.logger.Error(fmt.Sprintf("%s: sub websocket error, op: %s, topic: %s, err-code: %v, err-msg: %v", logPrefix, msg.Operation, msg.Topic, msg.ErrCode, msg.ErrMsg)) } case msg.Operation == AUTH: if msg.ErrCode != 0 { - m.logger.Error(fmt.Sprintf("auth websocket error, op: %s, err-code: %v, err-msg: %v", msg.Operation, msg.ErrCode, msg.ErrMsg)) + m.logger.Info(fmt.Sprintf("%s: auth websocket error, op: %s, err-code: %v, err-msg: %v", logPrefix, msg.Operation, msg.ErrCode, msg.ErrMsg)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } else { - m.logger.Info(fmt.Sprintf("auth websocket success, op: %s, err-code: %v", msg.Operation, msg.ErrCode)) + m.logger.Info(fmt.Sprintf("%s: auth websocket success, op: %s, err-code: %v", logPrefix, msg.Operation, msg.ErrCode)) } case msg.Operation == "notify": err := m.handle(&msg) if err != nil { - m.logger.Error(fmt.Sprintf("handle message error: %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/htx/usdm/accountws/client_test.go b/htx/usdm/accountws/client_test.go index 5b9d05b..09496ec 100644 --- a/htx/usdm/accountws/client_test.go +++ b/htx/usdm/accountws/client_test.go @@ -18,7 +18,6 @@ package accountws import ( - "context" "fmt" "os" "testing" @@ -27,12 +26,13 @@ import ( "github.com/stretchr/testify/assert" ) -func testNewAccountWsClient(ctx context.Context, t *testing.T, url string) *AccountWsClient { - cli, err := NewAccountWsClient(ctx, &AccountWsClientCfg{ - BaseURL: url, - Debug: true, - Key: os.Getenv("HTX_KEY"), - Secret: os.Getenv("HTX_SECRET"), +func testNewAccountWsClient(t *testing.T, url string) *AccountWsClient { + cli, err := NewAccountWsClient(&AccountWsClientCfg{ + Debug: true, + BaseURL: url, + AutoReconnect: true, + Key: os.Getenv("HTX_KEY"), + Secret: os.Getenv("HTX_SECRET"), }) if err != nil { @@ -43,10 +43,7 @@ func testNewAccountWsClient(ctx context.Context, t *testing.T, url string) *Acco } func TestSubscribeAccountUpdate(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewAccountWsClient(ctx, t, GlobalOrderWsBaseURL) + cli := testNewAccountWsClient(t, GlobalOrderWsBaseURL) topic, err := cli.GetCrossAccountUpdateTopic("ETH-USDT") assert.Nil(t, err) @@ -67,10 +64,7 @@ func TestSubscribeAccountUpdate(t *testing.T) { } func TestSubscribeUnifyAccountUpdate(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewAccountWsClient(ctx, t, GlobalOrderWsBaseURL) + cli := testNewAccountWsClient(t, GlobalOrderWsBaseURL) topic, err := cli.GetUnifyAccountUpdateTopic() assert.Nil(t, err) diff --git a/htx/usdm/accountws/vars.go b/htx/usdm/accountws/vars.go index 64cae89..b567bf4 100644 --- a/htx/usdm/accountws/vars.go +++ b/htx/usdm/accountws/vars.go @@ -21,6 +21,10 @@ const ( GlobalOrderWsBaseURL = "wss://api.hbdm.com/linear-swap-notification" ) +const ( + logPrefix = "htx::usdm::accountws" +) + const ( MaxTryTimes = 5 diff --git a/htx/usdm/marketws/client.go b/htx/usdm/marketws/client.go index 7c46bfd..ee1893d 100644 --- a/htx/usdm/marketws/client.go +++ b/htx/usdm/marketws/client.go @@ -43,7 +43,9 @@ type MarketWsClient struct { // logger logger *slog.Logger - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -58,13 +60,14 @@ type MarketWsClient struct { } type MarketWsClientCfg struct { - BaseURL string `validate:"required"` - Debug bool + Debug bool + BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` // Logger Logger *slog.Logger } -func NewMarketWsClient(ctx context.Context, cfg *MarketWsClientCfg) (*MarketWsClient, error) { +func NewMarketWsClient(cfg *MarketWsClientCfg) (*MarketWsClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } @@ -74,8 +77,7 @@ func NewMarketWsClient(ctx context.Context, cfg *MarketWsClientCfg) (*MarketWsCl debug: cfg.Debug, logger: cfg.Logger, - ctx: ctx, - autoReconnect: true, + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -85,12 +87,33 @@ func NewMarketWsClient(ctx context.Context, cfg *MarketWsClientCfg) (*MarketWsCl cli.logger = slog.Default() } - err := cli.start() + return cli, nil +} + +func (m *MarketWsClient) Open() error { + if m.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + m.stopCtx, m.cancel = context.WithCancel(context.Background()) + + err := m.start() if err != nil { - return nil, err + return err } - return cli, nil + return nil +} + +func (m *MarketWsClient) Close() error { + if m.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + m.cancel() + m.stopCtx = nil + + return nil } func (m *MarketWsClient) start() error { @@ -101,7 +124,7 @@ func (m *MarketWsClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := m.connect() if err != nil { - m.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + m.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -113,6 +136,8 @@ func (m *MarketWsClient) start() error { return errors.New("connect failed") } + m.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, m.baseURL)) + m.setIsConnected(true) m.resubscribe() @@ -143,15 +168,14 @@ func (m *MarketWsClient) reconnect() { m.setIsConnected(false) - m.logger.Info("disconnect, then reconnect...") - time.Sleep(1 * time.Second) select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("never reconnect, %s", m.ctx.Err())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + m.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) m.start() } } @@ -187,18 +211,20 @@ func (m *MarketWsClient) IsConnected() bool { func (m *MarketWsClient) readMessages() { for { select { - case <-m.ctx.Done(): - m.logger.Info(fmt.Sprintf("context done, error: %s", m.ctx.Err().Error())) + case <-m.stopCtx.Done(): + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: msgType, buf, err := m.conn.ReadMessage() if err != nil { - m.logger.Error(fmt.Sprintf("read message error, %s", err)) + m.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err)) time.Sleep(TimerIntervalSecond * time.Second) continue } @@ -207,12 +233,15 @@ func (m *MarketWsClient) readMessages() { if msgType == websocket.BinaryMessage { message, err := htxutils.GZipDecompress(buf) if err != nil { - m.logger.Error(fmt.Sprintf("ungzip data error: %s", err)) + m.logger.Info(fmt.Sprintf("%s: ungzip data error: %s", logPrefix, err)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -220,12 +249,15 @@ func (m *MarketWsClient) readMessages() { err = json.Unmarshal([]byte(message), &msg) if err != nil { - m.logger.Error(fmt.Sprintf("read object error, %s", err)) + m.logger.Info(fmt.Sprintf("%s: read object error, %s", logPrefix, err)) + m.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := m.close(); err != nil { - m.logger.Error(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + m.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return } @@ -235,14 +267,14 @@ func (m *MarketWsClient) readMessages() { Pong: msg.Ping.Ping, }) if err != nil { - m.logger.Error(fmt.Sprintf("handle ping error: %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: handle ping error: %s", logPrefix, err.Error())) } case msg.Response != nil: // todo case msg.SubscribedMessage != nil: err := m.handle(msg.SubscribedMessage) if err != nil { - m.logger.Error(fmt.Sprintf("handle message error: %s", err.Error())) + m.logger.Error(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/htx/usdm/marketws/client_test.go b/htx/usdm/marketws/client_test.go index 822a260..aca3421 100644 --- a/htx/usdm/marketws/client_test.go +++ b/htx/usdm/marketws/client_test.go @@ -18,7 +18,6 @@ package marketws import ( - "context" "fmt" "testing" @@ -27,8 +26,8 @@ import ( "github.com/stretchr/testify/assert" ) -func testNewMarketWsClient(ctx context.Context, t *testing.T, url string) *MarketWsClient { - cli, err := NewMarketWsClient(ctx, &MarketWsClientCfg{ +func testNewMarketWsClient(t *testing.T, url string) *MarketWsClient { + cli, err := NewMarketWsClient(&MarketWsClientCfg{ BaseURL: url, Debug: true, }) @@ -41,10 +40,7 @@ func testNewMarketWsClient(ctx context.Context, t *testing.T, url string) *Marke } func TestSubscribeKline(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketWsClient(ctx, t, GlobalMarketWsBaseURL) + cli := testNewMarketWsClient(t, GlobalMarketWsBaseURL) topic, err := cli.GetKlineTopic(&KlineTopicParam{ ContractCode: "BTC-USDT", @@ -68,10 +64,7 @@ func TestSubscribeKline(t *testing.T) { } func TestSubscribeDepth(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewMarketWsClient(ctx, t, GlobalMarketWsBaseURL) + cli := testNewMarketWsClient(t, GlobalMarketWsBaseURL) topic, err := cli.GetDepthTopic(&DepthTopicParam{ ContractCode: "BTC-USDT", diff --git a/htx/usdm/marketws/subscriptions.go b/htx/usdm/marketws/subscriptions.go index 8aa5685..b76e357 100644 --- a/htx/usdm/marketws/subscriptions.go +++ b/htx/usdm/marketws/subscriptions.go @@ -35,7 +35,7 @@ func (m *MarketWsClient) UnSubscribe(topic string) error { func (m *MarketWsClient) handle(msg *SubscribedMessage) error { if m.debug { - m.logger.Info(fmt.Sprintf("subscribed message, channel: %s", msg.Channel)) + m.logger.Info(fmt.Sprintf("%s: subscribed message, channel: %s", logPrefix, msg.Channel)) } switch { diff --git a/htx/usdm/marketws/vars.go b/htx/usdm/marketws/vars.go index 4cf7d62..1c50f21 100644 --- a/htx/usdm/marketws/vars.go +++ b/htx/usdm/marketws/vars.go @@ -23,6 +23,10 @@ const ( GlobalSystemStatusWsBaseURL = "wss://api.hbdm.com/center-notification" ) +const ( + logPrefix = "htx::usdm::marketws" +) + const ( MaxTryTimes = 5 diff --git a/woox/websocket/client.go b/woox/websocket/client.go index 0bb8312..00adaef 100644 --- a/woox/websocket/client.go +++ b/woox/websocket/client.go @@ -42,7 +42,9 @@ type WooXWebsocketClient struct { // logger logger *slog.Logger - ctx context.Context + stopCtx context.Context + cancel context.CancelFunc + conn *websocket.Conn mu sync.RWMutex isConnected bool @@ -58,30 +60,34 @@ type WooXWebsocketClient struct { } type WooXWebsocketCfg struct { + Debug bool BaseURL string `validate:"required"` + AutoReconnect bool `validate:"required"` + Key string Secret string ApplicationID string `validate:"required"` - Debug bool + // Logger Logger *slog.Logger } -func NewWooXWebsocketClient(ctx context.Context, cfg *WooXWebsocketCfg) (*WooXWebsocketClient, error) { +func NewWooXWebsocketClient(cfg *WooXWebsocketCfg) (*WooXWebsocketClient, error) { if err := validator.New().Struct(cfg); err != nil { return nil, err } cli := &WooXWebsocketClient{ - baseURL: cfg.BaseURL, + debug: cfg.Debug, + baseURL: cfg.BaseURL, + key: cfg.Key, secret: cfg.Secret, applicationID: cfg.ApplicationID, - debug: cfg.Debug, - logger: cfg.Logger, - ctx: ctx, - autoReconnect: true, + logger: cfg.Logger, + + autoReconnect: cfg.AutoReconnect, subscriptions: cmap.New[struct{}](), emitter: emission.NewEmitter(), @@ -91,12 +97,33 @@ func NewWooXWebsocketClient(ctx context.Context, cfg *WooXWebsocketCfg) (*WooXWe cli.logger = slog.Default() } - err := cli.start() + return cli, nil +} + +func (w *WooXWebsocketClient) Open() error { + if w.stopCtx != nil { + return fmt.Errorf("%s: ws is already open", logPrefix) + } + + w.stopCtx, w.cancel = context.WithCancel(context.Background()) + + err := w.start() if err != nil { - return nil, err + return err } - return cli, nil + return nil +} + +func (w *WooXWebsocketClient) Close() error { + if w.stopCtx == nil { + return fmt.Errorf("%s: ws is not open", logPrefix) + } + + w.cancel() + w.stopCtx = nil + + return nil } func (w *WooXWebsocketClient) start() error { @@ -108,7 +135,7 @@ func (w *WooXWebsocketClient) start() error { for i := 0; i < MaxTryTimes; i++ { conn, _, err := w.connect() if err != nil { - w.logger.Info(fmt.Sprintf("connect error, times(%v), error: %s", i, err.Error())) + w.logger.Info(fmt.Sprintf("%s: connect error, times(%v), error: %s", logPrefix, i, err.Error())) tm := (i + 1) * 5 time.Sleep(time.Duration(tm) * time.Second) continue @@ -120,6 +147,8 @@ func (w *WooXWebsocketClient) start() error { return errors.New("connect failed") } + w.logger.Info(fmt.Sprintf("%s: connect success, base_url: %s", logPrefix, w.baseURL)) + w.setIsConnected(true) w.resubscribe() @@ -152,17 +181,16 @@ func (w *WooXWebsocketClient) reconnect() { w.setIsConnected(false) - w.logger.Info("disconnect, then reconnect...") - close(w.heartCancel) time.Sleep(1 * time.Second) select { - case <-w.ctx.Done(): - w.logger.Info(fmt.Sprintf("never reconnect, %s", w.ctx.Err())) + case <-w.stopCtx.Done(): + w.logger.Info(fmt.Sprintf("%s: reconnection exits", logPrefix)) return default: + w.logger.Info(fmt.Sprintf("%s: try to reconnect...", logPrefix)) w.start() } } @@ -213,24 +241,30 @@ func (w *WooXWebsocketClient) heartbeat() { func (w *WooXWebsocketClient) readMessages() { for { select { - case <-w.ctx.Done(): - w.logger.Info(fmt.Sprintf("context done, error: %s", w.ctx.Err().Error())) + case <-w.stopCtx.Done(): + w.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := w.close(); err != nil { - w.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + w.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + w.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) return default: var msg types.AnyMessage err := w.conn.ReadJSON(&msg) if err != nil { - w.logger.Info(fmt.Sprintf("read object error, %s", err)) + w.logger.Info(fmt.Sprintf("%s: read message error, %s", logPrefix, err)) + w.logger.Info(fmt.Sprintf("%s: ready to close...", logPrefix)) if err := w.close(); err != nil { - w.logger.Info(fmt.Sprintf("websocket connection closed error, %s", err.Error())) + w.logger.Error(fmt.Sprintf("%s: connection closed error, %s", logPrefix, err.Error())) + return } + w.logger.Info(fmt.Sprintf("%s: connection closed success", logPrefix)) + return } @@ -240,7 +274,7 @@ func (w *WooXWebsocketClient) readMessages() { case msg.SubscribedMessage != nil: err := w.handle(msg.SubscribedMessage) if err != nil { - w.logger.Info(fmt.Sprintf("handle message error: %s", err.Error())) + w.logger.Info(fmt.Sprintf("%s: handle message error: %s", logPrefix, err.Error())) } } } diff --git a/woox/websocket/client_test.go b/woox/websocket/client_test.go index cac0845..4dbd7e6 100644 --- a/woox/websocket/client_test.go +++ b/woox/websocket/client_test.go @@ -18,7 +18,6 @@ package websocket import ( - "context" "fmt" "os" "testing" @@ -28,13 +27,14 @@ import ( "github.com/stretchr/testify/assert" ) -func testNewWooXWebsocketClient(ctx context.Context, t *testing.T) *WooXWebsocketClient { - cli, err := NewWooXWebsocketClient(ctx, &WooXWebsocketCfg{ +func testNewWooXWebsocketClient(t *testing.T) *WooXWebsocketClient { + cli, err := NewWooXWebsocketClient(&WooXWebsocketCfg{ + Debug: true, BaseURL: TestNetPublicBaseURL, + AutoReconnect: true, Key: os.Getenv("WOOX_KEY"), Secret: os.Getenv("WOOX_SECRET"), ApplicationID: os.Getenv("WOOX_APP_ID"), // required - Debug: true, }) if err != nil { @@ -45,10 +45,7 @@ func testNewWooXWebsocketClient(ctx context.Context, t *testing.T) *WooXWebsocke } func TestWooXWebsocketClientConnection(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) err := cli.Subscribe([]string{"SPOT_WOO_USDT@orderbook"}) assert.Nil(t, err) @@ -57,10 +54,7 @@ func TestWooXWebsocketClientConnection(t *testing.T) { } func TestSubscribeOrderbook(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetOrderbookTopic("PERP_BTC_USDT") assert.Nil(t, err) @@ -93,10 +87,7 @@ func TestSubscribeOrderbook(t *testing.T) { } func TestSubscribeTrade(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetTradeTopic("PERP_BTC_USDT") assert.Nil(t, err) @@ -121,10 +112,7 @@ func TestSubscribeTrade(t *testing.T) { } func TestSubscribeTickerForSymbol(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetTickerTopic("PERP_BTC_USDT") assert.Nil(t, err) @@ -149,10 +137,7 @@ func TestSubscribeTickerForSymbol(t *testing.T) { } func TestSubscribeTickers(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetAllTickersTopic() assert.Nil(t, err) @@ -179,10 +164,7 @@ func TestSubscribeTickers(t *testing.T) { } func TestSubscribeBBOForSymbol(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetBboTopic("PERP_BTC_USDT") assert.Nil(t, err) @@ -207,10 +189,7 @@ func TestSubscribeBBOForSymbol(t *testing.T) { } func TestSubscribeBBOs(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetAllBbosTopic() assert.Nil(t, err) @@ -237,10 +216,7 @@ func TestSubscribeBBOs(t *testing.T) { } func TestSubscribeKline(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetKlineTopic(&KlineTopicParam{ Symbol: "PERP_BTC_USDT", @@ -268,10 +244,7 @@ func TestSubscribeKline(t *testing.T) { } func TestSubscribeIndexPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetIndexPriceTopic("SPOT_ETH_USDT") assert.Nil(t, err) @@ -295,10 +268,7 @@ func TestSubscribeIndexPrice(t *testing.T) { } func TestSubscribeMarkPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetMarkPriceTopic("PERP_BTC_USDT") assert.Nil(t, err) @@ -322,10 +292,7 @@ func TestSubscribeMarkPrice(t *testing.T) { } func TestSubscribeAllMarkPrice(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetMarkPricesTopic() assert.Nil(t, err) @@ -351,10 +318,7 @@ func TestSubscribeAllMarkPrice(t *testing.T) { } func TestSubscribeOpenInterest(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetOpenInterestTopic("PERP_BTC_USDT") assert.Nil(t, err) @@ -378,10 +342,7 @@ func TestSubscribeOpenInterest(t *testing.T) { } func TestSubscribeEstFundingRate(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - cli := testNewWooXWebsocketClient(ctx, t) + cli := testNewWooXWebsocketClient(t) topic, err := cli.GetEstFundingRateTopic("PERP_BTC_USDT") assert.Nil(t, err) diff --git a/woox/websocket/subscriptions.go b/woox/websocket/subscriptions.go index 611fcba..da4e73d 100644 --- a/woox/websocket/subscriptions.go +++ b/woox/websocket/subscriptions.go @@ -35,7 +35,7 @@ func (w *WooXWebsocketClient) UnSubscribe(topics []string) error { func (w *WooXWebsocketClient) handle(msg *types.SubscribedMessage) error { if w.debug { - w.logger.Info(fmt.Sprintf("subscribed message, topic: %s, timestamp: %v", msg.Topic, msg.Timestamp)) + w.logger.Info(fmt.Sprintf("%s: subscribed message, topic: %s, timestamp: %v", logPrefix, msg.Topic, msg.Timestamp)) } switch { diff --git a/woox/websocket/vars.go b/woox/websocket/vars.go index 34961b3..d316fa0 100644 --- a/woox/websocket/vars.go +++ b/woox/websocket/vars.go @@ -25,6 +25,10 @@ var ( PrivateBaseURL = "wss://wss.woo.org/v2/ws/private/stream/" ) +const ( + logPrefix = "woox::websocket" +) + const ( MaxTryTimes = 5 )