Skip to content
This repository has been archived by the owner on May 1, 2020. It is now read-only.

Commit

Permalink
Implement multi-connection sessions
Browse files Browse the repository at this point in the history
Implement the option to enable multi-connection sessions.
Add test case to verify correct behavior of the newly created option.
Change the way tests setup their servers to allow passing options.
  • Loading branch information
romshark committed Mar 14, 2018
1 parent 2afd3b6 commit 144c6ad
Show file tree
Hide file tree
Showing 32 changed files with 726 additions and 550 deletions.
1 change: 0 additions & 1 deletion client/requestSessionRestoration.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ func (clt *Client) requestSessionRestoration(sessionKey []byte) (*webwire.Sessio
)
if err != nil {
// TODO: check for error types
fmt.Println("ERR", err)
return nil, fmt.Errorf("Session restoration request failed: %s", err)
}

Expand Down
21 changes: 12 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func NewServer(opts Options) *Server {
clients: make([]*Client, 0),
clientsLock: &sync.Mutex{},
sessionsEnabled: sessionsEnabled,
SessionRegistry: newSessionRegistry(),
SessionRegistry: newSessionRegistry(opts.MaxSessionConnections),

// Internals
warnLog: log.New(
Expand Down Expand Up @@ -172,14 +172,14 @@ func (srv *Server) handleSessionRestore(msg *Message) error {

key := string(msg.Payload.Data)

sessionExists := srv.SessionRegistry.Exists(key)

if sessionExists {
if srv.SessionRegistry.maxConns > 0 &&
srv.SessionRegistry.SessionConnections(key)+1 > srv.SessionRegistry.maxConns {
msg.fail(Error{
"SESSION_ACTIVE",
fmt.Sprintf(
"The session identified by key: '%s' is already active",
Code: "MAX_CONN_REACHED",
Message: fmt.Sprintf(
"Session %s reached the maximum number of concurrent connections (%d)",
key,
srv.SessionRegistry.maxConns,
),
})
return nil
Expand All @@ -190,6 +190,7 @@ func (srv *Server) handleSessionRestore(msg *Message) error {
msg.fail(Error{
"INTERNAL_ERROR",
fmt.Sprintf(
// TODO: whoops, that's some master-yoda-style english, fix it
"Session restoration request not could have been fulfilled",
),
})
Expand Down Expand Up @@ -218,7 +219,9 @@ func (srv *Server) handleSessionRestore(msg *Message) error {
}

msg.Client.Session = session
srv.SessionRegistry.register(msg.Client)
if okay := srv.SessionRegistry.register(msg.Client); !okay {
panic(fmt.Errorf("The number of concurrent session connections was unexpectedly exceeded"))
}

msg.fulfill(Payload{
Encoding: EncodingUtf8,
Expand Down Expand Up @@ -378,7 +381,7 @@ func (srv *Server) ServeHTTP(
_, message, err := conn.ReadMessage()
if err != nil {
if newClient.Session != nil {
// Mark session as inactive
// Decrement number of connections for this clients session
srv.SessionRegistry.deregister(newClient)
}

Expand Down
83 changes: 64 additions & 19 deletions sessionRegistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,91 @@ import (
"sync"
)

// sessionRegistryEntry represents a session registry entry
type sessionRegistryEntry struct {
connections uint
client *Client
}

// sessionRegistry represents a thread safe registry of all currently active sessions
type sessionRegistry struct {
lock sync.RWMutex
registry map[string]*Client
maxConns uint
registry map[string]sessionRegistryEntry
}

// newSessionRegistry returns a new instance of a session registry
func newSessionRegistry() sessionRegistry {
// newSessionRegistry returns a new instance of a session registry.
// maxConns defines the maximum number of concurrent connections for a single session
// while zero stands for unlimited
func newSessionRegistry(maxConns uint) sessionRegistry {
return sessionRegistry{
lock: sync.RWMutex{},
registry: make(map[string]*Client),
maxConns: maxConns,
registry: make(map[string]sessionRegistryEntry),
}
}

// register registers the given clients session as a currently active session
func (asr *sessionRegistry) register(clt *Client) {
// register registers a new connection for the given clients session and returns true.
// Returns false if the given clients session already has the max number of connections assigned.
func (asr *sessionRegistry) register(clt *Client) bool {
asr.lock.Lock()
asr.registry[clt.Session.Key] = clt
asr.lock.Unlock()
defer asr.lock.Unlock()
if entry, exists := asr.registry[clt.Session.Key]; exists {
// Ensure max connections isn't exceeded
if asr.maxConns > 0 && entry.connections+1 > asr.maxConns {
return false
}
// Overwrite the current entry incrementing the number of connections
asr.registry[clt.Session.Key] = sessionRegistryEntry{
connections: entry.connections + 1,
client: entry.client,
}
return true
}
asr.registry[clt.Session.Key] = sessionRegistryEntry{
connections: 1,
client: clt,
}
return true
}

// deregister deregisters the given clients session from the list of currently active sessions
func (asr *sessionRegistry) deregister(clt *Client) {
// deregister decrements the number of connections assigned to the given clients session
// and returns true. If there's only one connection left then the session will be removed
// from the register and false will be returned
func (asr *sessionRegistry) deregister(clt *Client) bool {
asr.lock.Lock()
delete(asr.registry, clt.Session.Key)
asr.lock.Unlock()
defer asr.lock.Unlock()
if entry, exists := asr.registry[clt.Session.Key]; exists {
// If a single connection is left then remove the session
if entry.connections < 2 {
delete(asr.registry, clt.Session.Key)
return false
}
// Overwrite the current entry decrementing the number of connections
asr.registry[clt.Session.Key] = sessionRegistryEntry{
connections: entry.connections - 1,
client: entry.client,
}
}
return false
}

// Len returns the number of currently active sessions
func (asr *sessionRegistry) Len() int {
// ActiveSessions returns the number of currently active sessions
func (asr *sessionRegistry) ActiveSessions() int {
asr.lock.RLock()
len := len(asr.registry)
asr.lock.RUnlock()
return len
}

// Exists returns true if the session associated with the given key exists and is currently active
func (asr *sessionRegistry) Exists(sessionKey string) bool {
// SessionConnections returns the number of concurrent connections
// associated with the session associated with the given key.
// Returns zero if the session associated with the given key doesn't exist.
func (asr *sessionRegistry) SessionConnections(sessionKey string) uint {
asr.lock.RLock()
_, exists := asr.registry[sessionKey]
asr.lock.RUnlock()
return exists
defer asr.lock.RUnlock()
if sess, exists := asr.registry[sessionKey]; exists {
return sess.connections
}
return 0
}
9 changes: 5 additions & 4 deletions setupServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ import (

// Options represents the options for a headed server setup
type Options struct {
Addr string
Hooks Hooks
WarnLog io.Writer
ErrorLog io.Writer
Addr string
Hooks Hooks
MaxSessionConnections uint
WarnLog io.Writer
ErrorLog io.Writer
}

// SetDefaults sets default values to undefined options
Expand Down
54 changes: 28 additions & 26 deletions test/activeSessionRegistry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,38 @@ func TestActiveSessionRegistry(t *testing.T) {
// Initialize webwire server
srv, addr := setupServer(
t,
webwire.Hooks{
OnRequest: func(ctx context.Context) (webwire.Payload, error) {
// Extract request message and requesting client from the context
msg := ctx.Value(webwire.Msg).(webwire.Message)
webwire.Options{
Hooks: webwire.Hooks{
OnRequest: func(ctx context.Context) (webwire.Payload, error) {
// Extract request message and requesting client from the context
msg := ctx.Value(webwire.Msg).(webwire.Message)

// Close session on logout
if msg.Name == "logout" {
if err := msg.Client.CloseSession(); err != nil {
t.Errorf("Couldn't close session: %s", err)
// Close session on logout
if msg.Name == "logout" {
if err := msg.Client.CloseSession(); err != nil {
t.Errorf("Couldn't close session: %s", err)
}
return webwire.Payload{}, nil
}
return webwire.Payload{}, nil
}

// Try to create a new session
if err := msg.Client.CreateSession(nil); err != nil {
return webwire.Payload{}, webwire.Error{
Code: "INTERNAL_ERROR",
Message: fmt.Sprintf("Internal server error: %s", err),
// Try to create a new session
if err := msg.Client.CreateSession(nil); err != nil {
return webwire.Payload{}, webwire.Error{
Code: "INTERNAL_ERROR",
Message: fmt.Sprintf("Internal server error: %s", err),
}
}
}

// Return the key of the newly created session (use default binary encoding)
return webwire.Payload{
Data: []byte(msg.Client.Session.Key),
}, nil
// Return the key of the newly created session (use default binary encoding)
return webwire.Payload{
Data: []byte(msg.Client.Session.Key),
}, nil
},
// Define dummy hooks for sessions to be enabled on this server
OnSessionCreated: func(_ *webwire.Client) error { return nil },
OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil },
OnSessionClosed: func(_ *webwire.Client) error { return nil },
},
// Define dummy hooks for sessions to be enabled on this server
OnSessionCreated: func(_ *webwire.Client) error { return nil },
OnSessionLookup: func(_ string) (*webwire.Session, error) { return nil, nil },
OnSessionClosed: func(_ *webwire.Client) error { return nil },
},
)

Expand Down Expand Up @@ -75,7 +77,7 @@ func TestActiveSessionRegistry(t *testing.T) {
t.Fatalf("Request failed: %s", err)
}

activeSessionNumberBefore := srv.SessionRegistry.Len()
activeSessionNumberBefore := srv.SessionRegistry.ActiveSessions()
if activeSessionNumberBefore != 1 {
t.Fatalf(
"Unexpected active session number after authentication: %d",
Expand All @@ -94,7 +96,7 @@ func TestActiveSessionRegistry(t *testing.T) {
t.Fatalf("Request failed: %s", err)
}

activeSessionNumberAfter := srv.SessionRegistry.Len()
activeSessionNumberAfter := srv.SessionRegistry.ActiveSessions()
if activeSessionNumberAfter != 0 {
t.Fatalf("Unexpected active session number after logout: %d", activeSessionNumberAfter)
}
Expand Down
68 changes: 35 additions & 33 deletions test/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,45 +65,47 @@ func TestAuthentication(t *testing.T) {
// Initialize webwire server
_, addr := setupServer(
t,
wwr.Hooks{
OnSignal: func(ctx context.Context) {
defer clientSignalReceived.Done()
// Extract request message and requesting client from the context
msg := ctx.Value(wwr.Msg).(wwr.Message)
compareSessions(t, createdSession, msg.Client.Session)
compareSessionInfo(msg.Client.Session)
},
OnRequest: func(ctx context.Context) (wwr.Payload, error) {
// Extract request message and requesting client from the context
msg := ctx.Value(wwr.Msg).(wwr.Message)

// If already authenticated then check session
if currentStep > 1 {
wwr.Options{
Hooks: wwr.Hooks{
OnSignal: func(ctx context.Context) {
defer clientSignalReceived.Done()
// Extract request message and requesting client from the context
msg := ctx.Value(wwr.Msg).(wwr.Message)
compareSessions(t, createdSession, msg.Client.Session)
compareSessionInfo(msg.Client.Session)
return expectedConfirmation, nil
}

// Try to create a new session
if err := msg.Client.CreateSession(sessionInfo); err != nil {
return wwr.Payload{}, wwr.Error{
Code: "INTERNAL_ERROR",
Message: fmt.Sprintf("Internal server error: %s", err),
},
OnRequest: func(ctx context.Context) (wwr.Payload, error) {
// Extract request message and requesting client from the context
msg := ctx.Value(wwr.Msg).(wwr.Message)

// If already authenticated then check session
if currentStep > 1 {
compareSessions(t, createdSession, msg.Client.Session)
compareSessionInfo(msg.Client.Session)
return expectedConfirmation, nil
}
}

// Authentication step is passed
currentStep = 2
// Try to create a new session
if err := msg.Client.CreateSession(sessionInfo); err != nil {
return wwr.Payload{}, wwr.Error{
Code: "INTERNAL_ERROR",
Message: fmt.Sprintf("Internal server error: %s", err),
}
}

// Return the key of the newly created session (use default binary encoding)
return wwr.Payload{
Data: []byte(msg.Client.Session.Key),
}, nil
// Authentication step is passed
currentStep = 2

// Return the key of the newly created session (use default binary encoding)
return wwr.Payload{
Data: []byte(msg.Client.Session.Key),
}, nil
},
// Define dummy hooks to enable sessions on this server
OnSessionCreated: func(_ *wwr.Client) error { return nil },
OnSessionLookup: func(_ string) (*wwr.Session, error) { return nil, nil },
OnSessionClosed: func(_ *wwr.Client) error { return nil },
},
// Define dummy hooks to enable sessions on this server
OnSessionCreated: func(_ *wwr.Client) error { return nil },
OnSessionLookup: func(_ string) (*wwr.Session, error) { return nil, nil },
OnSessionClosed: func(_ *wwr.Client) error { return nil },
},
)

Expand Down
Loading

1 comment on commit 144c6ad

@romshark
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit fulfills this feature request.

Please sign in to comment.