Skip to content

Commit

Permalink
Use functional option for setting logger in dmsg.
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed Jun 20, 2019
1 parent e316d53 commit 2d531ae
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 31 deletions.
29 changes: 22 additions & 7 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,20 @@ func (c *ClientConn) Close() error {
return nil
}

// ClientOption represents an optional argument for Client.
type ClientOption func(c *Client) error

// SetLogger sets the internal logger for Client.
func SetLogger(log *logging.Logger) ClientOption {
return func(c *Client) error {
if log == nil {
return errors.New("nil logger set")
}
c.log = log
return nil
}
}

// Client implements transport.Factory
type Client struct {
log *logging.Logger
Expand All @@ -289,8 +303,8 @@ type Client struct {
}

// NewClient creates a new Client.
func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client {
return &Client{
func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient, opts ...ClientOption) *Client {
c := &Client{
log: logging.MustGetLogger("dmsg_client"),
pk: pk,
sk: sk,
Expand All @@ -299,11 +313,12 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client
accept: make(chan *Transport, AcceptBufferSize),
done: make(chan struct{}),
}
}

// SetLogger sets the dms_client's logger.
func (c *Client) SetLogger(log *logging.Logger) {
c.log = log
for _, opt := range opts {
if err := opt(c); err != nil {
panic(err)
}
}
return c
}

func (c *Client) updateDiscEntry(ctx context.Context) error {
Expand Down
30 changes: 10 additions & 20 deletions pkg/dmsg/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,11 @@ func TestServer_Serve(t *testing.T) {
aPK, aSK := cipher.GenerateKeyPair()
bPK, bSK := cipher.GenerateKeyPair()

a := NewClient(aPK, aSK, dc)
a.SetLogger(logging.MustGetLogger("A"))
a := NewClient(aPK, aSK, dc, SetLogger(logging.MustGetLogger("A")))
err := a.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

b := NewClient(bPK, bSK, dc)
b.SetLogger(logging.MustGetLogger("B"))
b := NewClient(bPK, bSK, dc, SetLogger(logging.MustGetLogger("B")))
err = b.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

Expand Down Expand Up @@ -315,8 +313,7 @@ func TestServer_Serve(t *testing.T) {
for i := 0; i < initiatorsCount; i++ {
pk, sk := cipher.GenerateKeyPair()

c := NewClient(pk, sk, dc)
c.SetLogger(logging.MustGetLogger(fmt.Sprintf("Initiator %d", i)))
c := NewClient(pk, sk, dc, SetLogger(logging.MustGetLogger(fmt.Sprintf("initiator_%d", i))))
err := c.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

Expand All @@ -327,8 +324,7 @@ func TestServer_Serve(t *testing.T) {
for i := 0; i < remotesCount; i++ {
pk, sk := cipher.GenerateKeyPair()

c := NewClient(pk, sk, dc)
c.SetLogger(logging.MustGetLogger(fmt.Sprintf("Remote %d", i)))
c := NewClient(pk, sk, dc, SetLogger(logging.MustGetLogger(fmt.Sprintf("remote_%d", i))))
if _, ok := usedRemotes[i]; ok {
err := c.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)
Expand Down Expand Up @@ -550,14 +546,12 @@ func TestServer_Serve(t *testing.T) {
bPK, bSK := cipher.GenerateKeyPair()

// create remote
a := NewClient(aPK, aSK, dc)
a.SetLogger(logging.MustGetLogger("A"))
a := NewClient(aPK, aSK, dc, SetLogger(logging.MustGetLogger("A")))
err = a.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

// create initiator
b := NewClient(bPK, bSK, dc)
b.SetLogger(logging.MustGetLogger("B"))
b := NewClient(bPK, bSK, dc, SetLogger(logging.MustGetLogger("B")))
err = b.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

Expand Down Expand Up @@ -650,14 +644,12 @@ func TestServer_Serve(t *testing.T) {
bPK, bSK := cipher.GenerateKeyPair()

// create remote
a := NewClient(aPK, aSK, dc)
a.SetLogger(logging.MustGetLogger("A"))
a := NewClient(aPK, aSK, dc, SetLogger(logging.MustGetLogger("A")))
err = a.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

// create initiator
b := NewClient(bPK, bSK, dc)
b.SetLogger(logging.MustGetLogger("B"))
b := NewClient(bPK, bSK, dc, SetLogger(logging.MustGetLogger("B")))
err = b.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

Expand Down Expand Up @@ -747,12 +739,10 @@ func TestNewClient(t *testing.T) {

go s.Serve() //nolint:errcheck

a := NewClient(aPK, aSK, dc)
a.SetLogger(logging.MustGetLogger("A"))
a := NewClient(aPK, aSK, dc, SetLogger(logging.MustGetLogger("A")))
require.NoError(t, a.InitiateServerConnections(context.Background(), 1))

b := NewClient(bPK, bSK, dc)
b.SetLogger(logging.MustGetLogger("B"))
b := NewClient(bPK, bSK, dc, SetLogger(logging.MustGetLogger("B")))
require.NoError(t, b.InitiateServerConnections(context.Background(), 1))

wg := new(sync.WaitGroup)
Expand Down
3 changes: 1 addition & 2 deletions pkg/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ func NewNode(config *Config) (*Node, error) {
return nil, fmt.Errorf("invalid Messaging config: %s", err)
}

node.messenger = dmsg.NewClient(mConfig.PubKey, mConfig.SecKey, mConfig.Discovery)
node.messenger.SetLogger(node.Logger.PackageLogger(dmsg.Type))
node.messenger = dmsg.NewClient(mConfig.PubKey, mConfig.SecKey, mConfig.Discovery, dmsg.SetLogger(node.Logger.PackageLogger(dmsg.Type)))

trDiscovery, err := config.TransportDiscovery()
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions pkg/setup/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ func NewNode(conf *Config, metrics metrics.Recorder) (*Node, error) {
if lvl, err := logging.LevelFromString(conf.LogLevel); err == nil {
logger.SetLevel(lvl)
}
messenger := dmsg.NewClient(pk, sk, mClient.NewHTTP(conf.Messaging.Discovery))
messenger.SetLogger(logger.PackageLogger(dmsg.Type))
messenger := dmsg.NewClient(pk, sk, mClient.NewHTTP(conf.Messaging.Discovery), dmsg.SetLogger(logger.PackageLogger(dmsg.Type)))

trDiscovery, err := trClient.NewHTTP(conf.TransportDiscovery, pk, sk)
if err != nil {
Expand Down

0 comments on commit 2d531ae

Please sign in to comment.