diff --git a/client/requestSessionRestoration.go b/client/requestSessionRestoration.go index b8c7f63..0610ecd 100644 --- a/client/requestSessionRestoration.go +++ b/client/requestSessionRestoration.go @@ -21,7 +21,6 @@ func (clt *Client) requestSessionRestoration(sessionKey []byte) (*webwire.Sessio ) if err != nil { // TODO: check for error types - fmt.Println("ERR", err) return nil, fmt.Errorf("Session restoration request failed: %s", err) } diff --git a/server.go b/server.go index 90f155f..9fe7052 100644 --- a/server.go +++ b/server.go @@ -134,7 +134,7 @@ func NewServer(opts Options) *Server { clients: make([]*Client, 0), clientsLock: &sync.Mutex{}, sessionsEnabled: sessionsEnabled, - SessionRegistry: newSessionRegistry(), + SessionRegistry: newSessionRegistry(opts.MaxSessionConnections), // Internals warnLog: log.New( @@ -172,14 +172,14 @@ func (srv *Server) handleSessionRestore(msg *Message) error { key := string(msg.Payload.Data) - sessionExists := srv.SessionRegistry.Exists(key) - - if sessionExists { + if srv.SessionRegistry.maxConns > 0 && + srv.SessionRegistry.SessionConnections(key)+1 > srv.SessionRegistry.maxConns { msg.fail(Error{ - "SESSION_ACTIVE", - fmt.Sprintf( - "The session identified by key: '%s' is already active", + Code: "MAX_CONN_REACHED", + Message: fmt.Sprintf( + "Session %s reached the maximum number of concurrent connections (%d)", key, + srv.SessionRegistry.maxConns, ), }) return nil @@ -190,6 +190,7 @@ func (srv *Server) handleSessionRestore(msg *Message) error { msg.fail(Error{ "INTERNAL_ERROR", fmt.Sprintf( + // TODO: whoops, that's some master-yoda-style english, fix it "Session restoration request not could have been fulfilled", ), }) @@ -218,7 +219,9 @@ func (srv *Server) handleSessionRestore(msg *Message) error { } msg.Client.Session = session - srv.SessionRegistry.register(msg.Client) + if okay := srv.SessionRegistry.register(msg.Client); !okay { + panic(fmt.Errorf("The number of concurrent session connections was unexpectedly exceeded")) + } msg.fulfill(Payload{ Encoding: EncodingUtf8, @@ -378,7 +381,7 @@ func (srv *Server) ServeHTTP( _, message, err := conn.ReadMessage() if err != nil { if newClient.Session != nil { - // Mark session as inactive + // Decrement number of connections for this clients session srv.SessionRegistry.deregister(newClient) } diff --git a/sessionRegistry.go b/sessionRegistry.go index f9acbfe..49337d8 100644 --- a/sessionRegistry.go +++ b/sessionRegistry.go @@ -4,46 +4,91 @@ import ( "sync" ) +// sessionRegistryEntry represents a session registry entry +type sessionRegistryEntry struct { + connections uint + client *Client +} + // sessionRegistry represents a thread safe registry of all currently active sessions type sessionRegistry struct { lock sync.RWMutex - registry map[string]*Client + maxConns uint + registry map[string]sessionRegistryEntry } -// newSessionRegistry returns a new instance of a session registry -func newSessionRegistry() sessionRegistry { +// newSessionRegistry returns a new instance of a session registry. +// maxConns defines the maximum number of concurrent connections for a single session +// while zero stands for unlimited +func newSessionRegistry(maxConns uint) sessionRegistry { return sessionRegistry{ lock: sync.RWMutex{}, - registry: make(map[string]*Client), + maxConns: maxConns, + registry: make(map[string]sessionRegistryEntry), } } -// register registers the given clients session as a currently active session -func (asr *sessionRegistry) register(clt *Client) { +// register registers a new connection for the given clients session and returns true. +// Returns false if the given clients session already has the max number of connections assigned. +func (asr *sessionRegistry) register(clt *Client) bool { asr.lock.Lock() - asr.registry[clt.Session.Key] = clt - asr.lock.Unlock() + defer asr.lock.Unlock() + if entry, exists := asr.registry[clt.Session.Key]; exists { + // Ensure max connections isn't exceeded + if asr.maxConns > 0 && entry.connections+1 > asr.maxConns { + return false + } + // Overwrite the current entry incrementing the number of connections + asr.registry[clt.Session.Key] = sessionRegistryEntry{ + connections: entry.connections + 1, + client: entry.client, + } + return true + } + asr.registry[clt.Session.Key] = sessionRegistryEntry{ + connections: 1, + client: clt, + } + return true } -// deregister deregisters the given clients session from the list of currently active sessions -func (asr *sessionRegistry) deregister(clt *Client) { +// deregister decrements the number of connections assigned to the given clients session +// and returns true. If there's only one connection left then the session will be removed +// from the register and false will be returned +func (asr *sessionRegistry) deregister(clt *Client) bool { asr.lock.Lock() - delete(asr.registry, clt.Session.Key) - asr.lock.Unlock() + defer asr.lock.Unlock() + if entry, exists := asr.registry[clt.Session.Key]; exists { + // If a single connection is left then remove the session + if entry.connections < 2 { + delete(asr.registry, clt.Session.Key) + return false + } + // Overwrite the current entry decrementing the number of connections + asr.registry[clt.Session.Key] = sessionRegistryEntry{ + connections: entry.connections - 1, + client: entry.client, + } + } + return false } -// Len returns the number of currently active sessions -func (asr *sessionRegistry) Len() int { +// ActiveSessions returns the number of currently active sessions +func (asr *sessionRegistry) ActiveSessions() int { asr.lock.RLock() len := len(asr.registry) asr.lock.RUnlock() return len } -// Exists returns true if the session associated with the given key exists and is currently active -func (asr *sessionRegistry) Exists(sessionKey string) bool { +// SessionConnections returns the number of concurrent connections +// associated with the session associated with the given key. +// Returns zero if the session associated with the given key doesn't exist. +func (asr *sessionRegistry) SessionConnections(sessionKey string) uint { asr.lock.RLock() - _, exists := asr.registry[sessionKey] - asr.lock.RUnlock() - return exists + defer asr.lock.RUnlock() + if sess, exists := asr.registry[sessionKey]; exists { + return sess.connections + } + return 0 } diff --git a/setupServer.go b/setupServer.go index 60ab78c..d527984 100644 --- a/setupServer.go +++ b/setupServer.go @@ -11,10 +11,11 @@ import ( // Options represents the options for a headed server setup type Options struct { - Addr string - Hooks Hooks - WarnLog io.Writer - ErrorLog io.Writer + Addr string + Hooks Hooks + MaxSessionConnections uint + WarnLog io.Writer + ErrorLog io.Writer } // SetDefaults sets default values to undefined options diff --git a/test/activeSessionRegistry_test.go b/test/activeSessionRegistry_test.go index d4a6ec7..5922096 100644 --- a/test/activeSessionRegistry_test.go +++ b/test/activeSessionRegistry_test.go @@ -17,36 +17,38 @@ func TestActiveSessionRegistry(t *testing.T) { // Initialize webwire server srv, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) - // Close session on logout - if msg.Name == "logout" { - if err := msg.Client.CloseSession(); err != nil { - t.Errorf("Couldn't close session: %s", err) + // Close session on logout + if msg.Name == "logout" { + if err := msg.Client.CloseSession(); err != nil { + t.Errorf("Couldn't close session: %s", err) + } + return webwire.Payload{}, nil } - return webwire.Payload{}, nil - } - // Try to create a new session - if err := msg.Client.CreateSession(nil); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + // Try to create a new session + if err := msg.Client.CreateSession(nil); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } } - } - // Return the key of the newly created session (use default binary encoding) - return webwire.Payload{ - Data: []byte(msg.Client.Session.Key), - }, nil + // Return the key of the newly created session (use default binary encoding) + return webwire.Payload{ + Data: []byte(msg.Client.Session.Key), + }, nil + }, + // Define dummy hooks for sessions to be enabled on this server + OnSessionCreated: func(_ *webwire.Client) error { return nil }, + OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hooks for sessions to be enabled on this server - OnSessionCreated: func(_ *webwire.Client) error { return nil }, - OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) @@ -75,7 +77,7 @@ func TestActiveSessionRegistry(t *testing.T) { t.Fatalf("Request failed: %s", err) } - activeSessionNumberBefore := srv.SessionRegistry.Len() + activeSessionNumberBefore := srv.SessionRegistry.ActiveSessions() if activeSessionNumberBefore != 1 { t.Fatalf( "Unexpected active session number after authentication: %d", @@ -94,7 +96,7 @@ func TestActiveSessionRegistry(t *testing.T) { t.Fatalf("Request failed: %s", err) } - activeSessionNumberAfter := srv.SessionRegistry.Len() + activeSessionNumberAfter := srv.SessionRegistry.ActiveSessions() if activeSessionNumberAfter != 0 { t.Fatalf("Unexpected active session number after logout: %d", activeSessionNumberAfter) } diff --git a/test/authentication_test.go b/test/authentication_test.go index 67889cd..ed1184a 100644 --- a/test/authentication_test.go +++ b/test/authentication_test.go @@ -65,45 +65,47 @@ func TestAuthentication(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - wwr.Hooks{ - OnSignal: func(ctx context.Context) { - defer clientSignalReceived.Done() - // Extract request message and requesting client from the context - msg := ctx.Value(wwr.Msg).(wwr.Message) - compareSessions(t, createdSession, msg.Client.Session) - compareSessionInfo(msg.Client.Session) - }, - OnRequest: func(ctx context.Context) (wwr.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(wwr.Msg).(wwr.Message) - - // If already authenticated then check session - if currentStep > 1 { + wwr.Options{ + Hooks: wwr.Hooks{ + OnSignal: func(ctx context.Context) { + defer clientSignalReceived.Done() + // Extract request message and requesting client from the context + msg := ctx.Value(wwr.Msg).(wwr.Message) compareSessions(t, createdSession, msg.Client.Session) compareSessionInfo(msg.Client.Session) - return expectedConfirmation, nil - } - - // Try to create a new session - if err := msg.Client.CreateSession(sessionInfo); err != nil { - return wwr.Payload{}, wwr.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + }, + OnRequest: func(ctx context.Context) (wwr.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(wwr.Msg).(wwr.Message) + + // If already authenticated then check session + if currentStep > 1 { + compareSessions(t, createdSession, msg.Client.Session) + compareSessionInfo(msg.Client.Session) + return expectedConfirmation, nil } - } - // Authentication step is passed - currentStep = 2 + // Try to create a new session + if err := msg.Client.CreateSession(sessionInfo); err != nil { + return wwr.Payload{}, wwr.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } + } - // Return the key of the newly created session (use default binary encoding) - return wwr.Payload{ - Data: []byte(msg.Client.Session.Key), - }, nil + // Authentication step is passed + currentStep = 2 + + // Return the key of the newly created session (use default binary encoding) + return wwr.Payload{ + Data: []byte(msg.Client.Session.Key), + }, nil + }, + // Define dummy hooks to enable sessions on this server + OnSessionCreated: func(_ *wwr.Client) error { return nil }, + OnSessionLookup: func(_ string) (*wwr.Session, error) { return nil, nil }, + OnSessionClosed: func(_ *wwr.Client) error { return nil }, }, - // Define dummy hooks to enable sessions on this server - OnSessionCreated: func(_ *wwr.Client) error { return nil }, - OnSessionLookup: func(_ string) (*wwr.Session, error) { return nil, nil }, - OnSessionClosed: func(_ *wwr.Client) error { return nil }, }, ) diff --git a/test/clientAutomaticSessionRestoration_test.go b/test/clientAutomaticSessionRestoration_test.go index f00b4b3..e5c4c76 100644 --- a/test/clientAutomaticSessionRestoration_test.go +++ b/test/clientAutomaticSessionRestoration_test.go @@ -22,60 +22,62 @@ func TestClientAutomaticSessionRestoration(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) - - if currentStep == 2 { - // Expect the session to have been automatically restored - compareSessions(t, createdSession, msg.Client.Session) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) + + if currentStep == 2 { + // Expect the session to have been automatically restored + compareSessions(t, createdSession, msg.Client.Session) + return webwire.Payload{}, nil + } + + // Try to create a new session + if err := msg.Client.CreateSession(nil); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } + } + + // Return the key of the newly created session return webwire.Payload{}, nil - } + }, + // Permanently store the session + OnSessionCreated: func(client *webwire.Client) error { + sessionStorage[client.Session.Key] = client.Session + return nil + }, + // Find session by key + OnSessionLookup: func(key string) (*webwire.Session, error) { + // Expect the key of the created session to be looked up + if key != createdSession.Key { + err := fmt.Errorf( + "Expected and looked up session keys differ: %s | %s", + createdSession.Key, + key, + ) + t.Fatalf("Session lookup mismatch: %s", err) + return nil, err + } - // Try to create a new session - if err := msg.Client.CreateSession(nil); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + if session, exists := sessionStorage[key]; exists { + return session, nil } - } - // Return the key of the newly created session - return webwire.Payload{}, nil - }, - // Permanently store the session - OnSessionCreated: func(client *webwire.Client) error { - sessionStorage[client.Session.Key] = client.Session - return nil - }, - // Find session by key - OnSessionLookup: func(key string) (*webwire.Session, error) { - // Expect the key of the created session to be looked up - if key != createdSession.Key { - err := fmt.Errorf( - "Expected and looked up session keys differ: %s | %s", + // Expect the session to be found + t.Fatalf( + "Expected session (%s) not found in: %v", createdSession.Key, - key, + sessionStorage, ) - t.Fatalf("Session lookup mismatch: %s", err) - return nil, err - } - - if session, exists := sessionStorage[key]; exists { - return session, nil - } - - // Expect the session to be found - t.Fatalf( - "Expected session (%s) not found in: %v", - createdSession.Key, - sessionStorage, - ) - return nil, nil + return nil, nil + }, + // Define dummy hook to enable sessions on this server + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hook to enable sessions on this server - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) diff --git a/test/clientConcurrentConnect_test.go b/test/clientConcurrentConnect_test.go index cc1d689..1ba5c36 100644 --- a/test/clientConcurrentConnect_test.go +++ b/test/clientConcurrentConnect_test.go @@ -16,7 +16,7 @@ func TestClientConcurrentConnect(t *testing.T) { finished := NewPending(concurrentAccessors, 2*time.Second, true) // Initialize webwire server - _, addr := setupServer(t, webwire.Hooks{}) + _, addr := setupServer(t, webwire.Options{}) // Initialize client client := webwireClient.NewClient( diff --git a/test/clientConcurrentRequest_test.go b/test/clientConcurrentRequest_test.go index aacde16..3e3fb10 100644 --- a/test/clientConcurrentRequest_test.go +++ b/test/clientConcurrentRequest_test.go @@ -19,10 +19,12 @@ func TestClientConcurrentRequest(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(_ context.Context) (webwire.Payload, error) { - finished.Done() - return webwire.Payload{}, nil + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(_ context.Context) (webwire.Payload, error) { + finished.Done() + return webwire.Payload{}, nil + }, }, }, ) diff --git a/test/clientConcurrentSignal_test.go b/test/clientConcurrentSignal_test.go index 4b48c73..723fa59 100644 --- a/test/clientConcurrentSignal_test.go +++ b/test/clientConcurrentSignal_test.go @@ -19,9 +19,11 @@ func TestClientConcurrentSignal(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnSignal: func(_ context.Context) { - finished.Done() + webwire.Options{ + Hooks: webwire.Hooks{ + OnSignal: func(_ context.Context) { + finished.Done() + }, }, }, ) diff --git a/test/clientDisconnectedHook_test.go b/test/clientDisconnectedHook_test.go index 9cc0121..3ae18fc 100644 --- a/test/clientDisconnectedHook_test.go +++ b/test/clientDisconnectedHook_test.go @@ -18,20 +18,22 @@ func TestClientDisconnectedHook(t *testing.T) { // Initialize webwire server given only the request _, addr := setupServer( t, - webwire.Hooks{ - OnClientConnected: func(clt *webwire.Client) { - connectedClient = clt - }, - OnClientDisconnected: func(clt *webwire.Client) { - if clt != connectedClient { - t.Errorf( - "Connected and disconnecting clients don't match: "+ - "disconnecting: %p | connected: %p", - clt, - connectedClient, - ) - } - disconnectedHookCalled.Done() + webwire.Options{ + Hooks: webwire.Hooks{ + OnClientConnected: func(clt *webwire.Client) { + connectedClient = clt + }, + OnClientDisconnected: func(clt *webwire.Client) { + if clt != connectedClient { + t.Errorf( + "Connected and disconnecting clients don't match: "+ + "disconnecting: %p | connected: %p", + clt, + connectedClient, + ) + } + disconnectedHookCalled.Done() + }, }, }, ) diff --git a/test/clientInitiatedSessionDestruction_test.go b/test/clientInitiatedSessionDestruction_test.go index eaeb3a4..6da82ab 100644 --- a/test/clientInitiatedSessionDestruction_test.go +++ b/test/clientInitiatedSessionDestruction_test.go @@ -30,53 +30,55 @@ func TestClientInitiatedSessionDestruction(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) - - // On step 2 - verify session creation and correctness - if currentStep == 2 { - compareSessions(t, createdSession, msg.Client.Session) - if string(msg.Payload.Data) != msg.Client.Session.Key { - t.Errorf( - "Clients session key doesn't match: "+ - "client: '%s' | server: '%s'", - string(msg.Payload.Data), - msg.Client.Session.Key, - ) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) + + // On step 2 - verify session creation and correctness + if currentStep == 2 { + compareSessions(t, createdSession, msg.Client.Session) + if string(msg.Payload.Data) != msg.Client.Session.Key { + t.Errorf( + "Clients session key doesn't match: "+ + "client: '%s' | server: '%s'", + string(msg.Payload.Data), + msg.Client.Session.Key, + ) + } + return webwire.Payload{}, nil } - return webwire.Payload{}, nil - } - // On step 4 - verify session destruction - if currentStep == 4 { - if msg.Client.Session != nil { - t.Errorf( - "Expected the session to be destroyed, got: %v", - msg.Client.Session, - ) + // On step 4 - verify session destruction + if currentStep == 4 { + if msg.Client.Session != nil { + t.Errorf( + "Expected the session to be destroyed, got: %v", + msg.Client.Session, + ) + } + return webwire.Payload{}, nil } - return webwire.Payload{}, nil - } - // On step 1 - authenticate and create a new session - if err := msg.Client.CreateSession(nil); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + // On step 1 - authenticate and create a new session + if err := msg.Client.CreateSession(nil); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } } - } - // Return the key of the newly created session - return webwire.Payload{ - Data: []byte(msg.Client.Session.Key), - }, nil + // Return the key of the newly created session + return webwire.Payload{ + Data: []byte(msg.Client.Session.Key), + }, nil + }, + // Define dummy hooks to enable sessions on this server + OnSessionCreated: func(_ *webwire.Client) error { return nil }, + OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hooks to enable sessions on this server - OnSessionCreated: func(_ *webwire.Client) error { return nil }, - OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) diff --git a/test/clientIsConnected_test.go b/test/clientIsConnected_test.go index bb83a7f..ef5a3c6 100644 --- a/test/clientIsConnected_test.go +++ b/test/clientIsConnected_test.go @@ -12,10 +12,7 @@ import ( // TestClientIsConnected verifies correct client.IsConnected reporting func TestClientIsConnected(t *testing.T) { // Initialize webwire server given only the request - _, addr := setupServer( - t, - webwire.Hooks{}, - ) + _, addr := setupServer(t, webwire.Options{}) // Initialize client client := webwireClient.NewClient( diff --git a/test/clientOfflineSessionClosure_test.go b/test/clientOfflineSessionClosure_test.go index cf57d92..a642af1 100644 --- a/test/clientOfflineSessionClosure_test.go +++ b/test/clientOfflineSessionClosure_test.go @@ -21,62 +21,64 @@ func TestClientOfflineSessionClosure(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) - - if currentStep == 2 { - // Expect the session to have been automatically restored - if msg.Client.Session != nil { - t.Errorf("Expected client to be anonymous") + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) + + if currentStep == 2 { + // Expect the session to have been automatically restored + if msg.Client.Session != nil { + t.Errorf("Expected client to be anonymous") + } + return webwire.Payload{}, nil } + + // Try to create a new session + if err := msg.Client.CreateSession(nil); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } + } + + // Return the key of the newly created session return webwire.Payload{}, nil - } + }, + // Permanently store the session + OnSessionCreated: func(client *webwire.Client) error { + sessionStorage[client.Session.Key] = client.Session + return nil + }, + // Find session by key + OnSessionLookup: func(key string) (*webwire.Session, error) { + // Expect the key of the created session to be looked up + if key != createdSession.Key { + err := fmt.Errorf( + "Expected and looked up session keys differ: %s | %s", + createdSession.Key, + key, + ) + t.Fatalf("Session lookup mismatch: %s", err) + return nil, err + } - // Try to create a new session - if err := msg.Client.CreateSession(nil); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + if session, exists := sessionStorage[key]; exists { + return session, nil } - } - // Return the key of the newly created session - return webwire.Payload{}, nil - }, - // Permanently store the session - OnSessionCreated: func(client *webwire.Client) error { - sessionStorage[client.Session.Key] = client.Session - return nil - }, - // Find session by key - OnSessionLookup: func(key string) (*webwire.Session, error) { - // Expect the key of the created session to be looked up - if key != createdSession.Key { - err := fmt.Errorf( - "Expected and looked up session keys differ: %s | %s", + // Expect the session to be found + t.Fatalf( + "Expected session (%s) not found in: %v", createdSession.Key, - key, + sessionStorage, ) - t.Fatalf("Session lookup mismatch: %s", err) - return nil, err - } - - if session, exists := sessionStorage[key]; exists { - return session, nil - } - - // Expect the session to be found - t.Fatalf( - "Expected session (%s) not found in: %v", - createdSession.Key, - sessionStorage, - ) - return nil, nil + return nil, nil + }, + // Define dummy hook to enable sessions on this server + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hook to enable sessions on this server - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) diff --git a/test/clientOnSessionClosed_test.go b/test/clientOnSessionClosed_test.go index 0556077..0d7810f 100644 --- a/test/clientOnSessionClosed_test.go +++ b/test/clientOnSessionClosed_test.go @@ -19,38 +19,40 @@ func TestClientOnSessionClosed(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) - // Try to create a new session - if err := msg.Client.CreateSession(nil); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + // Try to create a new session + if err := msg.Client.CreateSession(nil); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } } - } - go func() { - // Wait until the authentication request is finished - if err := authenticated.Wait(); err != nil { - t.Errorf("Authentication timed out") - return - } + go func() { + // Wait until the authentication request is finished + if err := authenticated.Wait(); err != nil { + t.Errorf("Authentication timed out") + return + } - // Close the session - if err := msg.Client.CloseSession(); err != nil { - t.Errorf("Couldn't close session: %s", err) - } - }() + // Close the session + if err := msg.Client.CloseSession(); err != nil { + t.Errorf("Couldn't close session: %s", err) + } + }() - return webwire.Payload{}, nil + return webwire.Payload{}, nil + }, + // Define dummy hooks to enable sessions on this server + OnSessionCreated: func(_ *webwire.Client) error { return nil }, + OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hooks to enable sessions on this server - OnSessionCreated: func(_ *webwire.Client) error { return nil }, - OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) diff --git a/test/clientOnSessionCreated_test.go b/test/clientOnSessionCreated_test.go index 7d77838..2c16fc2 100644 --- a/test/clientOnSessionCreated_test.go +++ b/test/clientOnSessionCreated_test.go @@ -20,24 +20,26 @@ func TestClientOnSessionCreated(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) - // Try to create a new session - if err := msg.Client.CreateSession(nil); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + // Try to create a new session + if err := msg.Client.CreateSession(nil); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } } - } - return webwire.Payload{}, nil + return webwire.Payload{}, nil + }, + // Define dummy hooks to enable sessions on this server + OnSessionCreated: func(_ *webwire.Client) error { return nil }, + OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hooks to enable sessions on this server - OnSessionCreated: func(_ *webwire.Client) error { return nil }, - OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) diff --git a/test/clientRequestError_test.go b/test/clientRequestError_test.go index 7adbaf7..e982b9f 100644 --- a/test/clientRequestError_test.go +++ b/test/clientRequestError_test.go @@ -25,10 +25,12 @@ func TestClientRequestError(t *testing.T) { // Initialize webwire server given only the request _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(_ context.Context) (webwire.Payload, error) { - // Fail the request by returning an error - return webwire.Payload{}, expectedReplyError + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(_ context.Context) (webwire.Payload, error) { + // Fail the request by returning an error + return webwire.Payload{}, expectedReplyError + }, }, }, ) diff --git a/test/clientRequestRegisterOnReply_test.go b/test/clientRequestRegisterOnReply_test.go index ba597ed..c6d1b05 100644 --- a/test/clientRequestRegisterOnReply_test.go +++ b/test/clientRequestRegisterOnReply_test.go @@ -18,18 +18,20 @@ func TestClientRequestRegisterOnReply(t *testing.T) { // Initialize webwire server given only the request _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Verify pending requests - pendingReqs := client.PendingRequests() - if pendingReqs != 1 { - t.Errorf("Unexpected pending requests: %d", pendingReqs) - return webwire.Payload{}, nil - } + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Verify pending requests + pendingReqs := client.PendingRequests() + if pendingReqs != 1 { + t.Errorf("Unexpected pending requests: %d", pendingReqs) + return webwire.Payload{}, nil + } - // Wait until the request times out - time.Sleep(300 * time.Millisecond) - return webwire.Payload{}, nil + // Wait until the request times out + time.Sleep(300 * time.Millisecond) + return webwire.Payload{}, nil + }, }, }, ) diff --git a/test/clientRequestRegisterOnTimeout_test.go b/test/clientRequestRegisterOnTimeout_test.go index c6acb1e..ce25247 100644 --- a/test/clientRequestRegisterOnTimeout_test.go +++ b/test/clientRequestRegisterOnTimeout_test.go @@ -18,18 +18,20 @@ func TestClientRequestRegisterOnTimeout(t *testing.T) { // Initialize webwire server given only the request _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Verify pending requests - pendingReqs := client.PendingRequests() - if pendingReqs != 1 { - t.Errorf("Unexpected pending requests: %d", pendingReqs) - return webwire.Payload{}, nil - } + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Verify pending requests + pendingReqs := client.PendingRequests() + if pendingReqs != 1 { + t.Errorf("Unexpected pending requests: %d", pendingReqs) + return webwire.Payload{}, nil + } - // Wait until the request times out - time.Sleep(300 * time.Millisecond) - return webwire.Payload{}, nil + // Wait until the request times out + time.Sleep(300 * time.Millisecond) + return webwire.Payload{}, nil + }, }, }, ) diff --git a/test/clientRequest_test.go b/test/clientRequest_test.go index cdf31e9..fecedd5 100644 --- a/test/clientRequest_test.go +++ b/test/clientRequest_test.go @@ -25,19 +25,21 @@ func TestClientRequest(t *testing.T) { // Initialize webwire server given only the request _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) - // Verify request payload - comparePayload( - t, - "client request", - expectedRequestPayload, - msg.Payload, - ) - return expectedReplyPayload, nil + // Verify request payload + comparePayload( + t, + "client request", + expectedRequestPayload, + msg.Payload, + ) + return expectedReplyPayload, nil + }, }, }, ) diff --git a/test/clientSessionInfo_test.go b/test/clientSessionInfo_test.go index 87da0a2..dfcc093 100644 --- a/test/clientSessionInfo_test.go +++ b/test/clientSessionInfo_test.go @@ -28,44 +28,46 @@ func TestClientSessionInfo(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) - // Try to create a new session - if err := msg.Client.CreateSession(struct { - SampleBool bool `json:"bool"` - SampleString string `json:"string"` - SampleInt uint32 `json:"int"` - SampleNumber float64 `json:"number"` - SampleArray []string `json:"array"` - SampleStruct struct { - SampleString string `json:"struct_string"` - } `json:"struct"` - }{ - SampleBool: expectedBool, - SampleString: expectedString, - SampleInt: expectedInt, - SampleNumber: expectedNumber, - SampleArray: expectedArray, - SampleStruct: struct { - SampleString string `json:"struct_string"` + // Try to create a new session + if err := msg.Client.CreateSession(struct { + SampleBool bool `json:"bool"` + SampleString string `json:"string"` + SampleInt uint32 `json:"int"` + SampleNumber float64 `json:"number"` + SampleArray []string `json:"array"` + SampleStruct struct { + SampleString string `json:"struct_string"` + } `json:"struct"` }{ - SampleString: expectedStruct.SampleString, - }, - }); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + SampleBool: expectedBool, + SampleString: expectedString, + SampleInt: expectedInt, + SampleNumber: expectedNumber, + SampleArray: expectedArray, + SampleStruct: struct { + SampleString string `json:"struct_string"` + }{ + SampleString: expectedStruct.SampleString, + }, + }); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } } - } - return webwire.Payload{}, nil + return webwire.Payload{}, nil + }, + // Define dummy hooks to enable sessions on this server + OnSessionCreated: func(_ *webwire.Client) error { return nil }, + OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hooks to enable sessions on this server - OnSessionCreated: func(_ *webwire.Client) error { return nil }, - OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) diff --git a/test/clientSessionRestoration_test.go b/test/clientSessionRestoration_test.go index 7a0889d..42d4890 100644 --- a/test/clientSessionRestoration_test.go +++ b/test/clientSessionRestoration_test.go @@ -21,60 +21,62 @@ func TestClientSessionRestoration(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) - - if currentStep == 2 { - // Expect the session to have been automatically restored - compareSessions(t, createdSession, msg.Client.Session) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) + + if currentStep == 2 { + // Expect the session to have been automatically restored + compareSessions(t, createdSession, msg.Client.Session) + return webwire.Payload{}, nil + } + + // Try to create a new session + if err := msg.Client.CreateSession(nil); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } + } + + // Return the key of the newly created session return webwire.Payload{}, nil - } + }, + // Permanently store the session + OnSessionCreated: func(client *webwire.Client) error { + sessionStorage[client.Session.Key] = client.Session + return nil + }, + // Find session by key + OnSessionLookup: func(key string) (*webwire.Session, error) { + // Expect the key of the created session to be looked up + if key != createdSession.Key { + err := fmt.Errorf( + "Expected and looked up session keys differ: %s | %s", + createdSession.Key, + key, + ) + t.Fatalf("Session lookup mismatch: %s", err) + return nil, err + } - // Try to create a new session - if err := msg.Client.CreateSession(nil); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), + if session, exists := sessionStorage[key]; exists { + return session, nil } - } - // Return the key of the newly created session - return webwire.Payload{}, nil - }, - // Permanently store the session - OnSessionCreated: func(client *webwire.Client) error { - sessionStorage[client.Session.Key] = client.Session - return nil - }, - // Find session by key - OnSessionLookup: func(key string) (*webwire.Session, error) { - // Expect the key of the created session to be looked up - if key != createdSession.Key { - err := fmt.Errorf( - "Expected and looked up session keys differ: %s | %s", + // Expect the session to be found + t.Fatalf( + "Expected session (%s) not found in: %v", createdSession.Key, - key, + sessionStorage, ) - t.Fatalf("Session lookup mismatch: %s", err) - return nil, err - } - - if session, exists := sessionStorage[key]; exists { - return session, nil - } - - // Expect the session to be found - t.Fatalf( - "Expected session (%s) not found in: %v", - createdSession.Key, - sessionStorage, - ) - return nil, nil + return nil, nil + }, + // Define dummy hook to enable sessions on this server + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hook to enable sessions on this server - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) diff --git a/test/clientSignal_test.go b/test/clientSignal_test.go index bfc78f6..21de9b6 100644 --- a/test/clientSignal_test.go +++ b/test/clientSignal_test.go @@ -22,21 +22,23 @@ func TestClientSignal(t *testing.T) { // Initialize webwire server given only the signal handler _, addr := setupServer( t, - webwire.Hooks{ - OnSignal: func(ctx context.Context) { - // Extract signal message from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) + webwire.Options{ + Hooks: webwire.Hooks{ + OnSignal: func(ctx context.Context) { + // Extract signal message from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) - // Verify signal payload - comparePayload( - t, - "client signal", - expectedSignalPayload, - msg.Payload, - ) + // Verify signal payload + comparePayload( + t, + "client signal", + expectedSignalPayload, + msg.Payload, + ) - // Synchronize, notify signal arrival - signalArrived.Done() + // Synchronize, notify signal arrival + signalArrived.Done() + }, }, }, ) diff --git a/test/emptyReplyUtf16_test.go b/test/emptyReplyUtf16_test.go index 7b605ce..f835905 100644 --- a/test/emptyReplyUtf16_test.go +++ b/test/emptyReplyUtf16_test.go @@ -15,12 +15,14 @@ func TestEmptyReplyUtf16(t *testing.T) { // Initialize webwire server given only the request _, addr := setupServer( t, - wwr.Hooks{ - OnRequest: func(_ context.Context) (wwr.Payload, error) { - // Return empty reply - return wwr.Payload{ - Encoding: wwr.EncodingUtf16, - }, nil + wwr.Options{ + Hooks: wwr.Hooks{ + OnRequest: func(_ context.Context) (wwr.Payload, error) { + // Return empty reply + return wwr.Payload{ + Encoding: wwr.EncodingUtf16, + }, nil + }, }, }, ) diff --git a/test/emptyReply_test.go b/test/emptyReply_test.go index cb023bb..5765afe 100644 --- a/test/emptyReply_test.go +++ b/test/emptyReply_test.go @@ -15,10 +15,12 @@ func TestEmptyReply(t *testing.T) { // Initialize webwire server given only the request _, addr := setupServer( t, - wwr.Hooks{ - OnRequest: func(_ context.Context) (wwr.Payload, error) { - // Return empty reply - return wwr.Payload{}, nil + wwr.Options{ + Hooks: wwr.Hooks{ + OnRequest: func(_ context.Context) (wwr.Payload, error) { + // Return empty reply + return wwr.Payload{}, nil + }, }, }, ) diff --git a/test/endpointMetadata_test.go b/test/endpointMetadata_test.go index a070d79..900afb8 100644 --- a/test/endpointMetadata_test.go +++ b/test/endpointMetadata_test.go @@ -15,10 +15,7 @@ func TestEndpointMetadata(t *testing.T) { expectedVersion := "1.1" // Initialize webwire server - _, addr := setupServer( - t, - webwire.Hooks{}, - ) + _, addr := setupServer(t, webwire.Options{}) // Initialize HTTP client var httpClient = &http.Client{ diff --git a/test/maxConcSessConn_test.go b/test/maxConcSessConn_test.go new file mode 100644 index 0000000..1587516 --- /dev/null +++ b/test/maxConcSessConn_test.go @@ -0,0 +1,93 @@ +package test + +import ( + "os" + "testing" + "time" + + wwr "github.com/qbeon/webwire-go" + wwrClient "github.com/qbeon/webwire-go/client" +) + +// TestMaxConcSessConn tests 4 maximum concurrent connections of a session +func TestMaxConcSessConn(t *testing.T) { + sessionStorage := make(map[string]*wwr.Session) + + var sessionKey string + concurrentConns := uint(4) + + // Initialize server + _, addr := setupServer( + t, + wwr.Options{ + MaxSessionConnections: concurrentConns, + Hooks: wwr.Hooks{ + OnClientConnected: func(client *wwr.Client) { + // Created the session for the first connecting client only + if len(sessionKey) < 1 { + if err := client.CreateSession(nil); err != nil { + t.Errorf("Unexpected error during session creation: %s", err) + } + sessionKey = client.Session.Key + } + }, + // Permanently store the session + OnSessionCreated: func(client *wwr.Client) error { + sessionStorage[client.Session.Key] = client.Session + return nil + }, + // Find session by key + OnSessionLookup: func(key string) (*wwr.Session, error) { + if session, exists := sessionStorage[key]; exists { + return session, nil + } + return nil, nil + }, + // Define dummy hook to enable sessions on this server + OnSessionClosed: func(_ *wwr.Client) error { return nil }, + }, + }, + ) + + // Initialize client + clients := make([]*wwrClient.Client, concurrentConns) + for i := uint(0); i < concurrentConns; i++ { + client := wwrClient.NewClient( + addr, + wwrClient.Hooks{}, + 5*time.Second, + os.Stdout, + os.Stderr, + ) + clients[i] = &client + + if err := client.Connect(); err != nil { + t.Fatalf("Couldn't connect client: %s", err) + } + + // Restore the session for all clients except the first one + if i > 0 { + if err := client.RestoreSession([]byte(sessionKey)); err != nil { + t.Fatalf("Unexpected error during manual session restoration: %s", err) + } + } + } + + // Ensure that the last superfluous client is rejected + superflousClient := wwrClient.NewClient( + addr, + wwrClient.Hooks{}, + 5*time.Second, + os.Stdout, + os.Stderr, + ) + + if err := superflousClient.Connect(); err != nil { + t.Fatalf("Couldn't connect superfluous client: %s", err) + } + + // Try to restore the session and expect this operation to fail due to reached limit + if err := superflousClient.RestoreSession([]byte(sessionKey)); err == nil { + t.Fatalf("Expected an error during superfluous client manual session restoration") + } +} diff --git a/test/requestNamespaces_test.go b/test/requestNamespaces_test.go index 5db2dc9..1540da8 100644 --- a/test/requestNamespaces_test.go +++ b/test/requestNamespaces_test.go @@ -24,21 +24,23 @@ func TestRequestNamespaces(t *testing.T) { // Initialize server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - msg := ctx.Value(webwire.Msg).(webwire.Message) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + msg := ctx.Value(webwire.Msg).(webwire.Message) - if currentStep == 1 && msg.Name != "" { - t.Errorf("Expected unnamed request, got: '%s'", msg.Name) - } - if currentStep == 2 && msg.Name != shortestPossibleName { - t.Errorf("Expected shortest possible request name, got: '%s'", msg.Name) - } - if currentStep == 3 && msg.Name != longestPossibleName { - t.Errorf("Expected longest possible request name, got: '%s'", msg.Name) - } + if currentStep == 1 && msg.Name != "" { + t.Errorf("Expected unnamed request, got: '%s'", msg.Name) + } + if currentStep == 2 && msg.Name != shortestPossibleName { + t.Errorf("Expected shortest possible request name, got: '%s'", msg.Name) + } + if currentStep == 3 && msg.Name != longestPossibleName { + t.Errorf("Expected longest possible request name, got: '%s'", msg.Name) + } - return webwire.Payload{}, nil + return webwire.Payload{}, nil + }, }, }, ) diff --git a/test/serverInitiatedSessionDestruction_test.go b/test/serverInitiatedSessionDestruction_test.go index c53a27c..8684ed3 100644 --- a/test/serverInitiatedSessionDestruction_test.go +++ b/test/serverInitiatedSessionDestruction_test.go @@ -29,76 +29,78 @@ func TestServerInitiatedSessionDestruction(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - webwire.Hooks{ - OnRequest: func(ctx context.Context) (webwire.Payload, error) { - // Extract request message and requesting client from the context - msg := ctx.Value(webwire.Msg).(webwire.Message) - - // On step 2 - verify session creation and correctness - if currentStep == 2 { - compareSessions(t, createdSession, msg.Client.Session) - if string(msg.Payload.Data) != msg.Client.Session.Key { - t.Errorf( - "Clients session key doesn't match: "+ - "client: '%s' | server: '%s'", - string(msg.Payload.Data), - msg.Client.Session.Key, - ) + webwire.Options{ + Hooks: webwire.Hooks{ + OnRequest: func(ctx context.Context) (webwire.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(webwire.Msg).(webwire.Message) + + // On step 2 - verify session creation and correctness + if currentStep == 2 { + compareSessions(t, createdSession, msg.Client.Session) + if string(msg.Payload.Data) != msg.Client.Session.Key { + t.Errorf( + "Clients session key doesn't match: "+ + "client: '%s' | server: '%s'", + string(msg.Payload.Data), + msg.Client.Session.Key, + ) + } + return webwire.Payload{}, nil } - return webwire.Payload{}, nil - } - // on step 3 - close session and verify its destruction - if currentStep == 3 { - /***********************************************************\ - Server-side session destruction initiation - \***********************************************************/ - // Attempt to destroy this clients session - // on the end of the first step - err := msg.Client.CloseSession() - if err != nil { - t.Errorf( - "Couldn't close the active session on the server: %s", - err, - ) + // on step 3 - close session and verify its destruction + if currentStep == 3 { + /***********************************************************\ + Server-side session destruction initiation + \***********************************************************/ + // Attempt to destroy this clients session + // on the end of the first step + err := msg.Client.CloseSession() + if err != nil { + t.Errorf( + "Couldn't close the active session on the server: %s", + err, + ) + } + + // Verify destruction + if msg.Client.Session != nil { + t.Errorf( + "Expected the session to be destroyed, got: %v", + msg.Client.Session, + ) + } + + return webwire.Payload{}, nil } - // Verify destruction - if msg.Client.Session != nil { - t.Errorf( - "Expected the session to be destroyed, got: %v", - msg.Client.Session, - ) + // On step 4 - verify session destruction + if currentStep == 4 { + if msg.Client.Session != nil { + t.Errorf("Expected the session to be destroyed") + } + return webwire.Payload{}, nil } - return webwire.Payload{}, nil - } - - // On step 4 - verify session destruction - if currentStep == 4 { - if msg.Client.Session != nil { - t.Errorf("Expected the session to be destroyed") + // On step 1 - authenticate and create a new session + if err := msg.Client.CreateSession(nil); err != nil { + return webwire.Payload{}, webwire.Error{ + Code: "INTERNAL_ERROR", + Message: fmt.Sprintf("Internal server error: %s", err), + } } - return webwire.Payload{}, nil - } - - // On step 1 - authenticate and create a new session - if err := msg.Client.CreateSession(nil); err != nil { - return webwire.Payload{}, webwire.Error{ - Code: "INTERNAL_ERROR", - Message: fmt.Sprintf("Internal server error: %s", err), - } - } - // Return the key of the newly created session - return webwire.Payload{ - Data: []byte(msg.Client.Session.Key), - }, nil + // Return the key of the newly created session + return webwire.Payload{ + Data: []byte(msg.Client.Session.Key), + }, nil + }, + // Define dummy hooks to enable sessions on this server + OnSessionCreated: func(_ *webwire.Client) error { return nil }, + OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, + OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, - // Define dummy hooks to enable sessions on this server - OnSessionCreated: func(_ *webwire.Client) error { return nil }, - OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil }, - OnSessionClosed: func(_ *webwire.Client) error { return nil }, }, ) diff --git a/test/serverSignal_test.go b/test/serverSignal_test.go index ba1c5e7..1fae4fe 100644 --- a/test/serverSignal_test.go +++ b/test/serverSignal_test.go @@ -24,12 +24,14 @@ func TestServerSignal(t *testing.T) { go func() { _, addr = setupServer( t, - webwire.Hooks{ - OnClientConnected: func(client *webwire.Client) { - // Send signal - if err := client.Signal("", expectedSignalPayload); err != nil { - t.Fatalf("Couldn't send signal to client: %s", err) - } + webwire.Options{ + Hooks: webwire.Hooks{ + OnClientConnected: func(client *webwire.Client) { + // Send signal + if err := client.Signal("", expectedSignalPayload); err != nil { + t.Fatalf("Couldn't send signal to client: %s", err) + } + }, }, }, ) diff --git a/test/signalNamespaces_test.go b/test/signalNamespaces_test.go index 43c803b..38e419e 100644 --- a/test/signalNamespaces_test.go +++ b/test/signalNamespaces_test.go @@ -27,28 +27,30 @@ func TestSignalNamespaces(t *testing.T) { // Initialize server _, addr := setupServer( t, - webwire.Hooks{ - OnSignal: func(ctx context.Context) { - msg := ctx.Value(webwire.Msg).(webwire.Message) + webwire.Options{ + Hooks: webwire.Hooks{ + OnSignal: func(ctx context.Context) { + msg := ctx.Value(webwire.Msg).(webwire.Message) - if currentStep == 1 && msg.Name != "" { - t.Errorf("Expected unnamed signal, got: '%s'", msg.Name) - } - if currentStep == 2 && msg.Name != shortestPossibleName { - t.Errorf("Expected shortest possible signal name, got: '%s'", msg.Name) - } - if currentStep == 3 && msg.Name != longestPossibleName { - t.Errorf("Expected longest possible signal name, got: '%s'", msg.Name) - } + if currentStep == 1 && msg.Name != "" { + t.Errorf("Expected unnamed signal, got: '%s'", msg.Name) + } + if currentStep == 2 && msg.Name != shortestPossibleName { + t.Errorf("Expected shortest possible signal name, got: '%s'", msg.Name) + } + if currentStep == 3 && msg.Name != longestPossibleName { + t.Errorf("Expected longest possible signal name, got: '%s'", msg.Name) + } - switch currentStep { - case 1: - unnamedSignalArrived.Done() - case 2: - shortestNameSignalArrived.Done() - case 3: - longestNameSignalArrived.Done() - } + switch currentStep { + case 1: + unnamedSignalArrived.Done() + case 2: + shortestNameSignalArrived.Done() + case 3: + longestNameSignalArrived.Done() + } + }, }, }, ) diff --git a/test/test.go b/test/test.go index b524978..7a75f60 100644 --- a/test/test.go +++ b/test/test.go @@ -11,17 +11,12 @@ import ( // setupServer helps setting up // and launching the server together with the hosting http server -func setupServer( - t *testing.T, - hooks wwr.Hooks, -) (*wwr.Server, string) { +func setupServer(t *testing.T, opts wwr.Options) (*wwr.Server, string) { // Setup headed server on arbitrary port - srv, _, addr, run, err := wwr.SetupServer(wwr.Options{ - Addr: "127.0.0.1:0", - Hooks: hooks, - WarnLog: os.Stdout, - ErrorLog: os.Stderr, - }) + opts.Addr = "127.0.0.1:0" + opts.WarnLog = os.Stdout + opts.ErrorLog = os.Stderr + srv, _, addr, run, err := wwr.SetupServer(opts) if err != nil { t.Fatalf("Failed setting up server instance: %s", err) }