Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to process STOMP headers #15

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.github.mthizo247.cloud.netflix.zuul.web.authentication.stomp;

import com.github.mthizo247.cloud.netflix.zuul.web.socket.WebSocketMessageAccessor;
import com.github.mthizo247.cloud.netflix.zuul.web.socket.WebSocketStompHeadersCallback;
import org.springframework.http.HttpHeaders;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.web.socket.WebSocketSession;

/**
* @author Yurii Vlasiuk
* @version 1.0
* @since 30.05.2018
*/
public class AuthorizationStompHeadersCallback implements WebSocketStompHeadersCallback {

@Override
public void applyHeaders(WebSocketSession userAgentSession, WebSocketMessageAccessor accessor, StompHeaders stompHeaders) {
String authorizationHeader = accessor.getHeader(HttpHeaders.AUTHORIZATION);
if ((authorizationHeader != null) && (!authorizationHeader.isEmpty())) {
stompHeaders.add(HttpHeaders.AUTHORIZATION, authorizationHeader);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.github.mthizo247.cloud.netflix.zuul.web.authentication.stomp;

import com.github.mthizo247.cloud.netflix.zuul.web.socket.WebSocketMessageAccessor;
import com.github.mthizo247.cloud.netflix.zuul.web.socket.WebSocketStompHeadersCallback;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.web.socket.WebSocketSession;

import java.util.List;

/**
* @author Yurii Vlasiuk
* @version 1.0
* @since 30.05.2018
*/
public class CompositeStompHeadersCallback implements WebSocketStompHeadersCallback {

private List<WebSocketStompHeadersCallback> stompHeadersCallbacks;

public CompositeStompHeadersCallback(final List<WebSocketStompHeadersCallback> stompHeadersCallbacks) {
this.stompHeadersCallbacks = stompHeadersCallbacks;
}

@Override
public void applyHeaders(WebSocketSession userAgentSession, WebSocketMessageAccessor accessor, StompHeaders stompHeaders) {
if ((stompHeadersCallbacks != null) && (!stompHeadersCallbacks.isEmpty())) {
for(WebSocketStompHeadersCallback callback: stompHeadersCallbacks) {
callback.applyHeaders(userAgentSession, accessor, stompHeaders);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,25 @@ public class ProxyWebSocketConnectionManager extends ConnectionManagerSupport
private final WebSocketStompClient stompClient;
private final WebSocketSession userAgentSession;
private final WebSocketHttpHeadersCallback httpHeadersCallback;
private final WebSocketStompHeadersCallback webSocketStompHeadersCallback;
private StompSession serverSession;
private Map<String, StompSession.Subscription> subscriptions = new ConcurrentHashMap<>();
private ErrorHandler errorHandler;
private SimpMessagingTemplate messagingTemplate;
private WebSocketMessageAccessor accessor;

public ProxyWebSocketConnectionManager(SimpMessagingTemplate messagingTemplate,
WebSocketStompClient stompClient, WebSocketSession userAgentSession,
WebSocketHttpHeadersCallback httpHeadersCallback, String uri) {
WebSocketHttpHeadersCallback httpHeadersCallback, String uri,
WebSocketStompHeadersCallback webSocketStompHeadersCallback,
WebSocketMessageAccessor accessor) {
super(uri);
this.messagingTemplate = messagingTemplate;
this.stompClient = stompClient;
this.userAgentSession = userAgentSession;
this.httpHeadersCallback = httpHeadersCallback;
this.webSocketStompHeadersCallback = webSocketStompHeadersCallback;
this.accessor = accessor;
}

public void errorHandler(ErrorHandler errorHandler) {
Expand All @@ -74,6 +80,14 @@ private WebSocketHttpHeaders buildWebSocketHttpHeaders() {
return wsHeaders;
}

private StompHeaders buildWebSocketStompHeaders() {
StompHeaders stompHeaders = new StompHeaders();
if (webSocketStompHeadersCallback != null) {
webSocketStompHeadersCallback.applyHeaders(userAgentSession, accessor, stompHeaders);
}
return stompHeaders;
}

@Override
protected void openConnection() {
connect();
Expand All @@ -82,7 +96,7 @@ protected void openConnection() {
public void connect() {
try {
serverSession = stompClient
.connect(getUri().toString(), buildWebSocketHttpHeaders(), this)
.connect(getUri(), buildWebSocketHttpHeaders(), buildWebSocketStompHeaders(), this)
.get();
} catch (Exception e) {
logger.error("Error connecting to web socket uri " + getUri(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
public class ProxyWebSocketHandler extends WebSocketHandlerDecorator {
private final Logger logger = LoggerFactory.getLogger(ProxyWebSocketHandler.class);
private final WebSocketHttpHeadersCallback headersCallback;
private final WebSocketStompHeadersCallback webSocketStompHeadersCallback;
private final SimpMessagingTemplate messagingTemplate;
private final ProxyTargetResolver proxyTargetResolver;
private final ZuulWebSocketProperties zuulWebSocketProperties;
Expand All @@ -56,6 +57,7 @@ public class ProxyWebSocketHandler extends WebSocketHandlerDecorator {
public ProxyWebSocketHandler(WebSocketHandler delegate,
WebSocketStompClient stompClient,
WebSocketHttpHeadersCallback headersCallback,
WebSocketStompHeadersCallback webSocketStompHeadersCallback,
SimpMessagingTemplate messagingTemplate,
ProxyTargetResolver proxyTargetResolver,
ZuulWebSocketProperties zuulWebSocketProperties) {
Expand All @@ -65,6 +67,7 @@ public ProxyWebSocketHandler(WebSocketHandler delegate,
this.messagingTemplate = messagingTemplate;
this.proxyTargetResolver = proxyTargetResolver;
this.zuulWebSocketProperties = zuulWebSocketProperties;
this.webSocketStompHeadersCallback = webSocketStompHeadersCallback;
}

public void errorHandler(ErrorHandler errorHandler) {
Expand Down Expand Up @@ -148,7 +151,7 @@ private void handleMessageFromClient(WebSocketSession session,

if (StompCommand.CONNECT.toString().equalsIgnoreCase(accessor.getCommand())) {
handled = true;
connectToProxiedTarget(session);
connectToProxiedTarget(session, accessor);
}

if (!handled) {
Expand All @@ -159,7 +162,7 @@ private void handleMessageFromClient(WebSocketSession session,
}
}

private void connectToProxiedTarget(WebSocketSession session) {
private void connectToProxiedTarget(WebSocketSession session, WebSocketMessageAccessor accessor) {
URI sessionUri = session.getUri();
ZuulWebSocketProperties.WsBrokerage wsBrokerage = getWebSocketBrokarage(
sessionUri);
Expand All @@ -180,7 +183,7 @@ private void connectToProxiedTarget(WebSocketSession session) {
.toUriString();

ProxyWebSocketConnectionManager connectionManager = new ProxyWebSocketConnectionManager(
messagingTemplate, stompClient, session, headersCallback, uri);
messagingTemplate, stompClient, session, headersCallback, uri, webSocketStompHeadersCallback, accessor);
connectionManager.errorHandler(this.errorHandler);
managers.put(session, connectionManager);
connectionManager.start();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.github.mthizo247.cloud.netflix.zuul.web.socket;

import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.web.socket.WebSocketSession;

/**
* @author Yurii Vlasiuk
* @version 1.0
* @since 30.05.2018
*/
public interface WebSocketStompHeadersCallback {
void applyHeaders(WebSocketSession userAgentSession, WebSocketMessageAccessor accessor, StompHeaders stompHeaders);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.github.mthizo247.cloud.netflix.zuul.web.authentication.CompositeHeadersCallback;
import com.github.mthizo247.cloud.netflix.zuul.web.authentication.LoginCookieHeadersCallback;
import com.github.mthizo247.cloud.netflix.zuul.web.authentication.OAuth2BearerPrincipalHeadersCallback;
import com.github.mthizo247.cloud.netflix.zuul.web.authentication.stomp.AuthorizationStompHeadersCallback;
import com.github.mthizo247.cloud.netflix.zuul.web.authentication.stomp.CompositeStompHeadersCallback;
import com.github.mthizo247.cloud.netflix.zuul.web.filter.ProxyRedirectFilter;
import com.github.mthizo247.cloud.netflix.zuul.web.proxytarget.CompositeProxyTargetResolver;
import com.github.mthizo247.cloud.netflix.zuul.web.proxytarget.EurekaProxyTargetResolver;
Expand Down Expand Up @@ -100,6 +102,10 @@ public class ZuulWebSocketConfiguration extends AbstractWebSocketMessageBrokerCo
@Autowired
@Qualifier("compositeHeadersCallback")
WebSocketHttpHeadersCallback webSocketHttpHeadersCallback;
@Autowired
@Qualifier("compositeStompHeadersCallback")
WebSocketStompHeadersCallback webSocketStompHeadersCallback;


@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
Expand Down Expand Up @@ -159,6 +165,7 @@ public void configureWebSocketTransport(WebSocketTransportRegistration registrat
public WebSocketHandler decorate(WebSocketHandler handler) {
ProxyWebSocketHandler proxyWebSocketHandler = new ProxyWebSocketHandler(
handler, stompClient, webSocketHttpHeadersCallback,
webSocketStompHeadersCallback,
messagingTemplate,
proxyTargetResolver,
zuulWebSocketProperties);
Expand Down Expand Up @@ -191,6 +198,17 @@ public WebSocketHttpHeadersCallback loginCookieHeadersCallback() {
return new LoginCookieHeadersCallback();
}

@Bean
@Primary
public WebSocketStompHeadersCallback compositeStompHeadersCallback(final List<WebSocketStompHeadersCallback> callbacks) {
return new CompositeStompHeadersCallback(callbacks);
}

@Bean
public WebSocketStompHeadersCallback authorizationStompHeadersCallback() {
return new AuthorizationStompHeadersCallback();
}

@Bean
public ProxyTargetResolver urlProxyTargetResolver(
final ZuulProperties zuulProperties) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.messaging.WebSocketStompClient;

import java.net.URI;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -49,17 +51,19 @@ public class ProxyWebSocketConnectionManagerTests {
private ListenableFuture<StompSession> listenableFuture = (ListenableFuture<StompSession>) mock(
ListenableFuture.class);
private ErrorHandler errHandler = mock(ErrorHandler.class);
private WebSocketMessageAccessor messageAccessor = WebSocketMessageAccessor.create("example");
private WebSocketStompHeadersCallback stompHeadersCallback = mock(WebSocketStompHeadersCallback.class);

@Before
public void init() throws Exception {
String uri = "http://example.com";
URI uri = URI.create("http://example.com");
proxyConnectionManager = new ProxyWebSocketConnectionManager(messagingTemplate,
stompClient, wsSession, headersCallback, uri);
stompClient, wsSession, headersCallback, uri.toString(), stompHeadersCallback, messageAccessor);

proxyConnectionManager.errorHandler(errHandler);

when(listenableFuture.get()).thenReturn(serverSession);
when(stompClient.connect(uri, new WebSocketHttpHeaders(),
when(stompClient.connect(uri, new WebSocketHttpHeaders(), new StompHeaders(),
proxyConnectionManager)).thenReturn(listenableFuture);
}

Expand Down