Skip to content

Commit

Permalink
Merge pull request #42 from dyamin/main
Browse files Browse the repository at this point in the history
Add Multi-Instrument Subscription and Custom Dialer Support
  • Loading branch information
amir-the-h authored Jun 5, 2024
2 parents 610b577 + 1501845 commit 1f0fb11
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 167 deletions.
138 changes: 83 additions & 55 deletions api/ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/amir-the-h/okex"
"github.com/amir-the-h/okex/events"
"github.com/gorilla/websocket"
"io"
"net/http"
"sync"
"time"
Expand All @@ -31,6 +32,7 @@ type ClientWs struct {
sendChan map[bool]chan []byte
url map[bool]okex.BaseURL
conn map[bool]*websocket.Conn
dialer *websocket.Dialer
apiKey string
secretKey []byte
passphrase string
Expand All @@ -55,19 +57,18 @@ const (
func NewClient(ctx context.Context, apiKey, secretKey, passphrase string, url map[bool]okex.BaseURL) *ClientWs {
ctx, cancel := context.WithCancel(ctx)
c := &ClientWs{
apiKey: apiKey,
secretKey: []byte(secretKey),
passphrase: passphrase,
ctx: ctx,
Cancel: cancel,
url: url,
sendChan: map[bool]chan []byte{true: make(chan []byte, 3), false: make(chan []byte, 3)},
DoneChan: make(chan interface{}),
StructuredEventChan: make(chan interface{}),
RawEventChan: make(chan *events.Basic),
conn: make(map[bool]*websocket.Conn),
lastTransmit: make(map[bool]*time.Time),
mu: map[bool]*sync.RWMutex{true: {}, false: {}},
apiKey: apiKey,
secretKey: []byte(secretKey),
passphrase: passphrase,
ctx: ctx,
Cancel: cancel,
url: url,
sendChan: map[bool]chan []byte{true: make(chan []byte, 3), false: make(chan []byte, 3)},
DoneChan: make(chan interface{}),
conn: make(map[bool]*websocket.Conn),
dialer: websocket.DefaultDialer,
lastTransmit: make(map[bool]*time.Time),
mu: map[bool]*sync.RWMutex{true: {}, false: {}},
}
c.Private = NewPrivate(c)
c.Public = NewPublic(c)
Expand Down Expand Up @@ -131,20 +132,24 @@ func (c *ClientWs) Login() error {
// Users can choose to subscribe to one or more channels, and the total length of multiple channels cannot exceed 4096 bytes.
//
// https://www.okex.com/docs-v5/en/#websocket-api-subscribe
func (c *ClientWs) Subscribe(p bool, ch []okex.ChannelName, args map[string]string) error {
count := 1
if len(ch) != 0 {
count = len(ch)
}
tmpArgs := make([]map[string]string, count)
tmpArgs[0] = args
for i, name := range ch {
tmpArgs[i] = map[string]string{}
tmpArgs[i]["channel"] = string(name)
for k, v := range args {
tmpArgs[i][k] = v
func (c *ClientWs) Subscribe(p bool, ch []okex.ChannelName, args ...map[string]string) error {
chCount := max(len(ch), 1)
tmpArgs := make([]map[string]string, chCount*len(args))

n := 0
for i := 0; i < chCount; i++ {
for _, arg := range args {
tmpArgs[n] = make(map[string]string)
for k, v := range arg {
tmpArgs[n][k] = v
}
if len(ch) > 0 {
tmpArgs[n]["channel"] = string(ch[i])
}
n++
}
}

return c.Send(p, okex.SubscribeOperation, tmpArgs)
}

Expand Down Expand Up @@ -205,6 +210,16 @@ func (c *ClientWs) SetChannels(errCh chan *events.Error, subCh chan *events.Subs
c.SuccessChan = sCh
}

// SetDialer sets a custom dialer for the WebSocket connection.
func (c *ClientWs) SetDialer(dialer *websocket.Dialer) {
c.dialer = dialer
}

func (c *ClientWs) SetEventChannels(structuredEventCh chan interface{}, rawEventCh chan *events.Basic) {
c.StructuredEventChan = structuredEventCh
c.RawEventChan = rawEventCh
}

// WaitForAuthorization waits for the auth response and try to log in if it was needed
func (c *ClientWs) WaitForAuthorization() error {
if c.Authorized {
Expand All @@ -225,16 +240,23 @@ func (c *ClientWs) WaitForAuthorization() error {

func (c *ClientWs) dial(p bool) error {
c.mu[p].Lock()
conn, res, err := websocket.DefaultDialer.Dial(string(c.url[p]), nil)
conn, res, err := c.dialer.Dial(string(c.url[p]), nil)
if err != nil {
var statusCode int
if res != nil {
statusCode = res.StatusCode
}
c.mu[p].Unlock()
return fmt.Errorf("error %d: %w", statusCode, err)
}
defer res.Body.Close()
c.conn[p] = conn
c.mu[p].Unlock()

defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
fmt.Printf("error closing body: %v\n", err)
}
}(res.Body)
go func() {
err := c.receiver(p)
if err != nil {
Expand All @@ -247,10 +269,10 @@ func (c *ClientWs) dial(p bool) error {
fmt.Printf("sender error: %v\n", err)
}
}()
c.conn[p] = conn
c.mu[p].Unlock()

return nil
}

func (c *ClientWs) sender(p bool) error {
ticker := time.NewTicker(time.Millisecond * 300)
defer ticker.Stop()
Expand Down Expand Up @@ -279,7 +301,11 @@ func (c *ClientWs) sender(p bool) error {
return err
}
case <-ticker.C:
if c.conn[p] != nil && (c.lastTransmit[p] == nil || (c.lastTransmit[p] != nil && time.Since(*c.lastTransmit[p]) > PingPeriod)) {
c.mu[p].RLock()
conn := c.conn[p]
lastTransmit := c.lastTransmit[p]
c.mu[p].RUnlock()
if conn != nil && (lastTransmit == nil || (lastTransmit != nil && time.Since(*lastTransmit) > PingPeriod)) {
go func() {
c.sendChan[p] <- []byte("ping")
}()
Expand All @@ -289,6 +315,7 @@ func (c *ClientWs) sender(p bool) error {
}
}
}

func (c *ClientWs) receiver(p bool) error {
for {
select {
Expand Down Expand Up @@ -326,6 +353,7 @@ func (c *ClientWs) receiver(p bool) error {
}
}
}

func (c *ClientWs) sign(method, path string) (string, string) {
t := time.Now().UTC().Unix()
ts := fmt.Sprint(t)
Expand All @@ -335,42 +363,42 @@ func (c *ClientWs) sign(method, path string) (string, string) {
h.Write(p)
return ts, base64.StdEncoding.EncodeToString(h.Sum(nil))
}

func (c *ClientWs) handleCancel(msg string) error {
go func() {
c.DoneChan <- msg
}()
return fmt.Errorf("operation cancelled: %s", msg)
}

// TODO: break each case into a separate function
func (c *ClientWs) process(data []byte, e *events.Basic) bool {
switch e.Event {
case "error":
e := events.Error{}
_ = json.Unmarshal(data, &e)
go func() {
if c.ErrChan != nil {
c.ErrChan <- &e
}()
}
return true
case "subscribe":
e := events.Subscribe{}
_ = json.Unmarshal(data, &e)
go func() {
if c.SubscribeChan != nil {
c.SubscribeChan <- &e
}
if c.SubscribeChan != nil {
c.SubscribeChan <- &e
}
if c.StructuredEventChan != nil {
c.StructuredEventChan <- e
}()
}
return true
case "unsubscribe":
e := events.Unsubscribe{}
_ = json.Unmarshal(data, &e)
go func() {
if c.UnsubscribeCh != nil {
c.UnsubscribeCh <- &e
}
if c.UnsubscribeCh != nil {
c.UnsubscribeCh <- &e
}
if c.StructuredEventChan != nil {
c.StructuredEventChan <- e
}()
}
return true
case "login":
if time.Since(*c.AuthRequested).Seconds() > 30 {
Expand All @@ -381,12 +409,12 @@ func (c *ClientWs) process(data []byte, e *events.Basic) bool {
c.Authorized = true
e := events.Login{}
_ = json.Unmarshal(data, &e)
go func() {
if c.LoginChan != nil {
c.LoginChan <- &e
}
if c.LoginChan != nil {
c.LoginChan <- &e
}
if c.StructuredEventChan != nil {
c.StructuredEventChan <- e
}()
}
return true
}
if c.Private.Process(data, e) {
Expand All @@ -403,14 +431,14 @@ func (c *ClientWs) process(data []byte, e *events.Basic) bool {
}
e := events.Success{}
_ = json.Unmarshal(data, &e)
go func() {
if c.SuccessChan != nil {
c.SuccessChan <- &e
}
if c.SuccessChan != nil {
c.SuccessChan <- &e
}
if c.StructuredEventChan != nil {
c.StructuredEventChan <- e
}()
}
return true
}
go func() { c.RawEventChan <- e }()
c.RawEventChan <- e
return false
}
Loading

0 comments on commit 1f0fb11

Please sign in to comment.