Skip to content

Commit

Permalink
Added the ability to not specify the host port during a web socket ha…
Browse files Browse the repository at this point in the history
…ndshake

Signed-off-by: Yehor Beskhmelnytsyn <[email protected]>
  • Loading branch information
Besik13 committed Nov 24, 2020
1 parent d11139f commit fbf9544
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public class MqttConnectOptions {
private boolean automaticReconnect = false;
private int maxReconnectDelay = 128000;
private Properties customWebSocketHeaders = null;
private boolean skipPortDuringHandshake = false;

// Client Operation Parameters
private int executorServiceTimeout = 1; // How long to wait in seconds when terminating the executor service.
Expand All @@ -99,7 +100,7 @@ public MqttConnectOptions() {

/**
* Returns the password to use for the connection.
*
*
* @return the password to use for the connection.
*/
public char[] getPassword() {
Expand All @@ -108,7 +109,7 @@ public char[] getPassword() {

/**
* Sets the password to use for the connection.
*
*
* @param password
* A Char Array of the password
*/
Expand All @@ -118,7 +119,7 @@ public void setPassword(char[] password) {

/**
* Returns the user name to use for the connection.
*
*
* @return the user name to use for the connection.
*/
public String getUserName() {
Expand All @@ -127,7 +128,7 @@ public String getUserName() {

/**
* Sets the user name to use for the connection.
*
*
* @param userName
* The Username as a String
*/
Expand All @@ -137,7 +138,7 @@ public void setUserName(String userName) {

/**
* Get the maximum time (in millis) to wait between reconnects
*
*
* @return Get the maximum time (in millis) to wait between reconnects
*/
public int getMaxReconnectDelay() {
Expand All @@ -146,7 +147,7 @@ public int getMaxReconnectDelay() {

/**
* Set the maximum time to wait between reconnects
*
*
* @param maxReconnectDelay
* the duration (in millis)
*/
Expand Down Expand Up @@ -206,7 +207,7 @@ private void validateWill(String dest, Object payload) {

/**
* Sets up the will information, based on the supplied parameters.
*
*
* @param topic
* the topic to send the LWT message to
* @param msg
Expand All @@ -227,7 +228,7 @@ protected void setWill(String topic, MqttMessage msg, int qos, boolean retained)

/**
* Returns the "keep alive" interval.
*
*
* @see #setKeepAliveInterval(int)
* @return the keep alive interval.
*/
Expand All @@ -237,7 +238,7 @@ public int getKeepAliveInterval() {

/**
* Returns the MQTT version.
*
*
* @see #setMqttVersion(int)
* @return the MQTT version.
*/
Expand Down Expand Up @@ -273,7 +274,7 @@ public void setKeepAliveInterval(int keepAliveInterval) throws IllegalArgumentEx
/**
* Returns the "max inflight". The max inflight limits to how many messages we
* can send without receiving acknowledgments.
*
*
* @see #setMaxInflight(int)
* @return the max inflight
*/
Expand All @@ -287,7 +288,7 @@ public int getMaxInflight() {
* <p>
* The default value is 10
* </p>
*
*
* @param maxInflight
* the number of maxInfligt messages
*/
Expand All @@ -300,7 +301,7 @@ public void setMaxInflight(int maxInflight) {

/**
* Returns the connection timeout value.
*
*
* @see #setConnectionTimeout(int)
* @return the connection timeout value.
*/
Expand All @@ -314,7 +315,7 @@ public int getConnectionTimeout() {
* the MQTT server to be established. The default timeout is 30 seconds. A value
* of 0 disables timeout processing meaning the client will wait until the
* network connection is made successfully or fails.
*
*
* @param connectionTimeout
* the timeout value, measured in seconds. It must be &gt;0;
*/
Expand All @@ -328,7 +329,7 @@ public void setConnectionTimeout(int connectionTimeout) {
/**
* Returns the socket factory that will be used when connecting, or
* <code>null</code> if one has not been set.
*
*
* @return The Socket Factory
*/
public SocketFactory getSocketFactory() {
Expand All @@ -340,7 +341,7 @@ public SocketFactory getSocketFactory() {
* apply its own policies around the creation of network sockets. If using an
* SSL connection, an <code>SSLSocketFactory</code> can be used to supply
* application-specific security settings.
*
*
* @param socketFactory
* the factory to use.
*/
Expand All @@ -350,7 +351,7 @@ public void setSocketFactory(SocketFactory socketFactory) {

/**
* Returns the topic to be used for last will and testament (LWT).
*
*
* @return the MqttTopic to use, or <code>null</code> if LWT is not set.
* @see #setWill(MqttTopic, byte[], int, boolean)
*/
Expand All @@ -362,7 +363,7 @@ public String getWillDestination() {
* Returns the message to be sent as last will and testament (LWT). The returned
* object is "read only". Calling any "setter" methods on the returned object
* will result in an <code>IllegalStateException</code> being thrown.
*
*
* @return the message to use, or <code>null</code> if LWT is not set.
*/
public MqttMessage getWillMessage() {
Expand All @@ -371,7 +372,7 @@ public MqttMessage getWillMessage() {

/**
* Returns the SSL properties for the connection.
*
*
* @return the properties for the SSL connection
*/
public Properties getSSLProperties() {
Expand Down Expand Up @@ -447,7 +448,7 @@ public Properties getSSLProperties() {
* object instead of using the default algorithm available in the platform.
* Example values: "PKIX" or "IBMJ9X509".</dd>
* </dl>
*
*
* @param props
* The SSL {@link Properties}
*/
Expand All @@ -465,7 +466,7 @@ public void setHttpsHostnameVerificationEnabled(boolean httpsHostnameVerificatio

/**
* Returns the HostnameVerifier for the SSL connection.
*
*
* @return the HostnameVerifier for the SSL connection
*/
public HostnameVerifier getSSLHostnameVerifier() {
Expand All @@ -479,7 +480,7 @@ public HostnameVerifier getSSLHostnameVerifier() {
* <p>
* There is no default HostnameVerifier
* </p>
*
*
* @param hostnameVerifier
* the {@link HostnameVerifier}
*/
Expand All @@ -490,7 +491,7 @@ public void setSSLHostnameVerifier(HostnameVerifier hostnameVerifier) {
/**
* Returns whether the client and server should remember state for the client
* across reconnects.
*
*
* @return the clean session flag
*/
public boolean isCleanSession() {
Expand All @@ -517,7 +518,7 @@ public boolean isCleanSession() {
* <li>The server will treat a subscription as non-durable
* </ul>
* </ul>
*
*
* @param cleanSession
* Set to True to enable cleanSession
*/
Expand All @@ -527,7 +528,7 @@ public void setCleanSession(boolean cleanSession) {

/**
* Return a list of serverURIs the client may connect to
*
*
* @return the serverURIs or null if not set
*/
public String[] getServerURIs() {
Expand Down Expand Up @@ -580,7 +581,7 @@ public String[] getServerURIs() {
* </p>
* </li>
* </ol>
*
*
* @param serverURIs
* to be used by the client
*/
Expand Down Expand Up @@ -615,7 +616,7 @@ public void setMqttVersion(int mqttVersion) throws IllegalArgumentException {
/**
* Returns whether the client will automatically attempt to reconnect to the
* server if the connection is lost
*
*
* @return the automatic reconnection flag.
*/
public boolean isAutomaticReconnect() {
Expand All @@ -634,14 +635,14 @@ public boolean isAutomaticReconnect() {
* double until it is at 2 minutes at which point the delay will stay at 2
* minutes.</li>
* </ul>
*
*
* @param automaticReconnect
* If set to True, Automatic Reconnect will be enabled
*/
public void setAutomaticReconnect(boolean automaticReconnect) {
this.automaticReconnect = automaticReconnect;
}

public int getExecutorServiceTimeout() {
return executorServiceTimeout;
}
Expand All @@ -650,13 +651,31 @@ public int getExecutorServiceTimeout() {
* Set the time in seconds that the executor service should wait when
* terminating before forcefully terminating. It is not recommended to change
* this value unless you are absolutely sure that you need to.
*
*
* @param executorServiceTimeout the time in seconds to wait when shutting down.Ï
*/
public void setExecutorServiceTimeout(int executorServiceTimeout) {
this.executorServiceTimeout = executorServiceTimeout;
}

/**
* Returns whether to skip a port during a handshake
*
* @return skipPortDuringHandshake
*/
public boolean isSkipPortDuringHandshake() {
return skipPortDuringHandshake;
}

/**
* Sets a flag that indicates whether to add a port to the host during a handshake
*
* @param skip if set to True, the port will not be added
*/
public void setSkipPortDuringHandshake(boolean skip) {
this.skipPortDuringHandshake = skip;
}

/**
* @return The Debug Properties
*/
Expand All @@ -679,6 +698,7 @@ public Properties getDebug() {
} else {
p.put("SSLProperties", getSSLProperties());
}
p.put("SkipPortDuringHandshake", isSkipPortDuringHandshake());
return p;
}

Expand All @@ -700,5 +720,5 @@ public Properties getCustomWebSocketHeaders() {
public String toString() {
return Debug.dumpProperties(getDebug(), "Connection options");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@ public class WebSocketHandshake {
private static final String HTTP_HEADER_CONNECTION_VALUE = "upgrade";
private static final String HTTP_HEADER_SEC_WEBSOCKET_PROTOCOL = "sec-websocket-protocol";

private final boolean skipPortDuringHandshake;

InputStream input;
OutputStream output;
String uri;
String host;
int port;
Properties customWebSocketHeaders;

public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port, Properties customWebSocketHeaders){
public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port, Properties customWebSocketHeaders, boolean skipPortDuringHandshake){
this.input = input;
this.output = output;
this.uri = uri;
this.host = host;
this.port = port;
this.customWebSocketHeaders = customWebSocketHeaders;
this.skipPortDuringHandshake = skipPortDuringHandshake;
}


Expand Down Expand Up @@ -99,7 +102,7 @@ private void sendHandshakeRequest(String key) throws IOException{

PrintWriter pw = new PrintWriter(output);
pw.print("GET " + path + " HTTP/1.1" + LINE_SEPARATOR);
if (port != 80) {
if (port != 80 && !skipPortDuringHandshake) {
pw.print("Host: " + host + ":" + port + LINE_SEPARATOR);
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class WebSocketNetworkModule extends TCPNetworkModule {
private Properties customWebsocketHeaders;
private PipedInputStream pipedInputStream;
private WebSocketReceiver webSocketReceiver;
private final boolean skipPortDuringHandshake;
ByteBuffer recievedPayload;

/**
Expand All @@ -50,20 +51,20 @@ public class WebSocketNetworkModule extends TCPNetworkModule {
*/
private ByteArrayOutputStream outputStream = new ExtendedByteArrayOutputStream(this);

public WebSocketNetworkModule(SocketFactory factory, String uri, String host, int port, String resourceContext, Properties customWebsocketHeaders){
public WebSocketNetworkModule(SocketFactory factory, String uri, String host, int port, String resourceContext, Properties customWebsocketHeaders, boolean skipPortDuringHandshake){
super(factory, host, port, resourceContext);
this.uri = uri;
this.host = host;
this.port = port;
this.customWebsocketHeaders = customWebsocketHeaders;
this.pipedInputStream = new PipedInputStream();

this.skipPortDuringHandshake = skipPortDuringHandshake;
log.setResourceName(resourceContext);
}

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port, customWebsocketHeaders);
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port, customWebsocketHeaders, skipPortDuringHandshake);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("webSocketReceiver");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public NetworkModule createNetworkModule(URI brokerUri, MqttConnectOptions optio
throw ExceptionHelper.createMqttException(MqttException.REASON_CODE_SOCKET_FACTORY_MISMATCH);
}
WebSocketNetworkModule netModule = new WebSocketNetworkModule(factory, brokerUri.toString(), host, port,
clientId, options.getCustomWebSocketHeaders());
clientId, options.getCustomWebSocketHeaders(), options.isSkipPortDuringHandshake());
netModule.setConnectTimeout(options.getConnectionTimeout());
return netModule;
}
Expand Down
Loading

0 comments on commit fbf9544

Please sign in to comment.