diff --git a/src/main/java/com/github/nkzawa/engineio/client/Socket.java b/src/main/java/com/github/nkzawa/engineio/client/Socket.java index 7372169ea88..2fb57df3f91 100644 --- a/src/main/java/com/github/nkzawa/engineio/client/Socket.java +++ b/src/main/java/com/github/nkzawa/engineio/client/Socket.java @@ -96,6 +96,8 @@ public void run() {} public static boolean priorWebsocketSuccess = false; + private static SSLContext defaultSSLContext; + private boolean secure; private boolean upgrade; private boolean timestampRequests; @@ -123,6 +125,9 @@ public void run() {} private ReadyState readyState; private ScheduledExecutorService heartbeatScheduler = Executors.newSingleThreadScheduledExecutor(); + public static void setDefaultSSLContext(SSLContext sslContext) { + defaultSSLContext = sslContext; + } public Socket() { this(new Options()); @@ -167,7 +172,7 @@ public Socket(Options opts) { } this.secure = opts.secure; - this.sslContext = opts.sslContext; + this.sslContext = opts.sslContext != null ? opts.sslContext : defaultSSLContext; this.hostname = opts.hostname != null ? opts.hostname : "localhost"; this.port = opts.port != 0 ? opts.port : (this.secure ? 443 : 80); this.query = opts.query != null ? diff --git a/src/test/java/com/github/nkzawa/engineio/client/SSLConnectionTest.java b/src/test/java/com/github/nkzawa/engineio/client/SSLConnectionTest.java index 5685f7dff29..c07255e0b53 100644 --- a/src/test/java/com/github/nkzawa/engineio/client/SSLConnectionTest.java +++ b/src/test/java/com/github/nkzawa/engineio/client/SSLConnectionTest.java @@ -1,6 +1,7 @@ package com.github.nkzawa.engineio.client; import com.github.nkzawa.emitter.Emitter; +import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -33,6 +34,11 @@ public boolean verify(String hostname, javax.net.ssl.SSLSession sslSession) { private Socket socket; + @After + public void tearDown() { + Socket.setDefaultSSLContext(null); + } + @Override Socket.Options createOptions() { Socket.Options opts = super.createOptions(); @@ -74,17 +80,12 @@ public void call(Object... args) { socket.on(Socket.EVENT_MESSAGE, new Emitter.Listener() { @Override public void call(Object... args) { - assertThat((String)args[0], is("hi")); + assertThat((String) args[0], is("hi")); socket.close(); latch.countDown(); } }); } - }).on("error", new Emitter.Listener() { - @Override - public void call(Object... args) { - ((Exception)args[0]).printStackTrace(); - } }); socket.open(); latch.await(); @@ -119,4 +120,27 @@ public void call(Object... args) { socket.open(); latch.await(); } + + @Test(timeout = TIMEOUT) + public void defaultSSLContext() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + + Socket.setDefaultSSLContext(createSSLContext()); + socket = new Socket(createOptions()); + socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { + @Override + public void call(Object... args) { + socket.on(Socket.EVENT_MESSAGE, new Emitter.Listener() { + @Override + public void call(Object... args) { + assertThat((String) args[0], is("hi")); + socket.close(); + latch.countDown(); + } + }); + } + }); + socket.open(); + latch.await(); + } }