Skip to content

Commit

Permalink
Add support for http forward proxy with CONNECT
Browse files Browse the repository at this point in the history
This is a squash and modification of master commits that also includes:
netty,okhttp: Fix CONNECT and its error handling

This commit has been modified to reduce its size to substantially reduce
risk of it breaking Netty error handling. But that also means proxy
error handling just provides a useless "there was an error" sort of
message.

There is no Java API to enable the proxy support. Instead, you must set
the GRPC_PROXY_EXP environment variable which should be set to a
host:port string. The environment variable is temporary; it will not
exist in future releases. It exists to provide support without needing
explicit code to enable the future, while at the same time not risking
enabling it for existing users.
  • Loading branch information
ejona86 committed Jan 27, 2017
1 parent 5bfac21 commit 23f5a6f
Show file tree
Hide file tree
Showing 8 changed files with 489 additions and 12 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ subprojects {

netty: 'io.netty:netty-codec-http2:[4.1.7.Final]',
netty_epoll: 'io.netty:netty-transport-native-epoll:4.1.7.Final' + epoll_suffix,
netty_proxy_handler: 'io.netty:netty-handler-proxy:4.1.7.Final',
netty_tcnative: 'io.netty:netty-tcnative-boringssl-static:1.1.33.Fork25',

// Test dependencies.
Expand Down
3 changes: 2 additions & 1 deletion netty/build.gradle
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
description = "gRPC: Netty"
dependencies {
compile project(':grpc-core'),
libraries.netty
libraries.netty,
libraries.netty_proxy_handler

// Tests depend on base class defined by core module.
testCompile project(':grpc-core').sourceSets.test.output,
Expand Down
19 changes: 19 additions & 0 deletions netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,25 @@ static ProtocolNegotiator createProtocolNegotiator(
String authority,
NegotiationType negotiationType,
SslContext sslContext) {
ProtocolNegotiator negotiator =
createProtocolNegotiatorByType(authority, negotiationType, sslContext);
String proxy = System.getenv("GRPC_PROXY_EXP");
if (proxy != null) {
String[] parts = proxy.split(":", 2);
int port = 80;
if (parts.length > 1) {
port = Integer.parseInt(parts[1]);
}
InetSocketAddress proxyAddress = new InetSocketAddress(parts[0], port);
negotiator = ProtocolNegotiators.httpProxy(proxyAddress, null, null, negotiator);
}
return negotiator;
}

private static ProtocolNegotiator createProtocolNegotiatorByType(
String authority,
NegotiationType negotiationType,
SslContext sslContext) {
switch (negotiationType) {
case PLAINTEXT:
return ProtocolNegotiators.plaintext();
Expand Down
93 changes: 90 additions & 3 deletions netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
Expand All @@ -54,13 +55,17 @@
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http2.Http2ClientUpgradeCodec;
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.ProxyConnectionEvent;
import io.netty.handler.proxy.ProxyHandler;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.OpenSslEngine;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.AsciiString;
import io.netty.util.ReferenceCountUtil;
import java.net.SocketAddress;
import java.net.URI;
import java.util.ArrayDeque;
import java.util.Arrays;
Expand Down Expand Up @@ -189,6 +194,73 @@ public AsciiString scheme() {
}
}

/**
* Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation.
*/
public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress,
final @Nullable String proxyUsername, final @Nullable String proxyPassword,
final ProtocolNegotiator negotiator) {
Preconditions.checkNotNull(proxyAddress, "proxyAddress");
Preconditions.checkNotNull(negotiator, "negotiator");
class ProxyNegotiator implements ProtocolNegotiator {
@Override
public Handler newHandler(GrpcHttp2ConnectionHandler http2Handler) {
HttpProxyHandler proxyHandler;
if (proxyUsername == null || proxyPassword == null) {
proxyHandler = new HttpProxyHandler(proxyAddress);
} else {
proxyHandler = new HttpProxyHandler(proxyAddress, proxyUsername, proxyPassword);
}
return new BufferUntilProxyTunnelledHandler(
proxyHandler, negotiator.newHandler(http2Handler));
}
}

return new ProxyNegotiator();
}

/**
* Buffers all writes until the HTTP CONNECT tunnel is established.
*/
static final class BufferUntilProxyTunnelledHandler extends AbstractBufferingHandler
implements ProtocolNegotiator.Handler {
private final ProtocolNegotiator.Handler originalHandler;

public BufferUntilProxyTunnelledHandler(
ProxyHandler proxyHandler, ProtocolNegotiator.Handler handler) {
super(proxyHandler, handler);
this.originalHandler = handler;
}


@Override
public AsciiString scheme() {
return originalHandler.scheme();
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProxyConnectionEvent) {
writeBufferedAndRemove(ctx);
}
super.userEventTriggered(ctx, evt);
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
fail(ctx, unavailableException("Connection broken while trying to CONNECT through proxy"));
super.channelInactive(ctx);
}

@Override
public void close(ChannelHandlerContext ctx, ChannelPromise future) throws Exception {
if (ctx.channel().isActive()) { // This may be a notification that the socket was closed
fail(ctx, unavailableException("Channel closed while trying to CONNECT through proxy"));
}
super.close(ctx, future);
}
}

/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
Expand Down Expand Up @@ -366,10 +438,22 @@ public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
* lifetime and we only want to configure it once.
*/
if (handlers != null) {
ctx.pipeline().addFirst(handlers);
for (ChannelHandler handler : handlers) {
ctx.pipeline().addBefore(ctx.name(), null, handler);
}
ChannelHandler handler0 = handlers[0];
ChannelHandlerContext handler0Ctx = ctx.pipeline().context(handlers[0]);
handlers = null;
if (handler0Ctx != null) { // The handler may have removed itself immediately
if (handler0 instanceof ChannelInboundHandler) {
((ChannelInboundHandler) handler0).channelRegistered(handler0Ctx);
} else {
handler0Ctx.fireChannelRegistered();
}
}
} else {
super.channelRegistered(ctx);
}
super.channelRegistered(ctx);
}

@Override
Expand Down Expand Up @@ -424,7 +508,10 @@ public void flush(ChannelHandlerContext ctx) {

@Override
public void close(ChannelHandlerContext ctx, ChannelPromise future) throws Exception {
fail(ctx, unavailableException("Channel closed while performing protocol negotiation"));
if (ctx.channel().isActive()) { // This may be a notification that the socket was closed
fail(ctx, unavailableException("Channel closed while performing protocol negotiation"));
}
super.close(ctx, future);
}

protected final void fail(ChannelHandlerContext ctx, Throwable cause) {
Expand Down
115 changes: 113 additions & 2 deletions netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,41 @@

package io.grpc.netty;

import static com.google.common.base.Charsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
import io.grpc.netty.ProtocolNegotiators.TlsNegotiator;
import io.grpc.testing.TestUtils;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
import java.io.File;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.logging.Filter;
import java.util.logging.Level;
import java.util.logging.LogRecord;
Expand All @@ -63,10 +79,17 @@
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
public class ProtocolNegotiatorsTest {
@Rule public final ExpectedException thrown = ExpectedException.none();
private static final Runnable NOOP_RUNNABLE = new Runnable() {
@Override public void run() {}
};

@Rule
public final ExpectedException thrown = ExpectedException.none();

private GrpcHttp2ConnectionHandler grpcHandler = mock(GrpcHttp2ConnectionHandler.class);

Expand All @@ -81,7 +104,7 @@ public void setUp() throws Exception {
File serverCert = TestUtils.loadCert("server1.pem");
File key = TestUtils.loadCert("server1.key");
sslContext = GrpcSslContexts.forServer(serverCert, key)
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
engine = SSLContext.getDefault().createSSLEngine();
}

Expand Down Expand Up @@ -272,4 +295,92 @@ public void tls_invalidHost() throws SSLException {
assertEquals("bad_host:1234", negotiator.getHost());
assertEquals(-1, negotiator.getPort());
}

@Test
public void httpProxy_nullAddressNpe() throws Exception {
thrown.expect(NullPointerException.class);
ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext());
}

@Test
public void httpProxy_nullNegotiatorNpe() throws Exception {
thrown.expect(NullPointerException.class);
ProtocolNegotiators.httpProxy(
InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null);
}

@Test
public void httpProxy_nullUserPassNoException() throws Exception {
assertNotNull(ProtocolNegotiators.httpProxy(
InetSocketAddress.createUnresolved("localhost", 80), null, null,
ProtocolNegotiators.plaintext()));
}

@Test(timeout = 5000)
public void httpProxy_completes() throws Exception {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
// ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called
// the channel is already active.
LocalAddress proxy = new LocalAddress("httpProxy_completes");
SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314);

ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class);
Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class)
.childHandler(mockHandler)
.bind(proxy).sync().channel();

ProtocolNegotiator nego =
ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext());
ChannelHandler handler = nego.newHandler(grpcHandler);
Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler)
.register().sync().channel();
pipeline = channel.pipeline();
// Wait for initialization to complete
channel.eventLoop().submit(NOOP_RUNNABLE).sync();
// The grpcHandler must be in the pipeline, but we don't actually want it during our test
// because it will consume all events since it is a mock. We only use it because it is required
// to construct the Handler.
pipeline.remove(grpcHandler);
channel.connect(host).sync();
serverChannel.close();
ArgumentCaptor<ChannelHandlerContext> contextCaptor =
ArgumentCaptor.forClass(ChannelHandlerContext.class);
Mockito.verify(mockHandler).channelActive(contextCaptor.capture());
ChannelHandlerContext serverContext = contextCaptor.getValue();

final String golden = "isThisThingOn?";
ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel));

// Wait for sending initial request to complete
channel.eventLoop().submit(NOOP_RUNNABLE).sync();
ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class);
Mockito.verify(mockHandler)
.channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
ByteBuf b = (ByteBuf) objectCaptor.getValue();
String request = b.toString(UTF_8);
b.release();
assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n"));
assertTrue("No CONNECT: " + request, request.startsWith("CONNECT specialHost:314 "));
assertTrue("No host header: " + request, request.contains("host: specialHost:314"));

assertFalse(negotiationFuture.isDone());
serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync();
negotiationFuture.sync();

channel.eventLoop().submit(NOOP_RUNNABLE).sync();
objectCaptor.getAllValues().clear();
Mockito.verify(mockHandler, times(2))
.channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
b = (ByteBuf) objectCaptor.getAllValues().get(1);
// If we were using the real grpcHandler, this would have been the HTTP/2 preface
String preface = b.toString(UTF_8);
b.release();
assertEquals(golden, preface);

channel.close();
}

private static ByteBuf bb(String s, Channel c) {
return ByteBufUtil.writeUtf8(c.alloc(), s);
}
}
13 changes: 12 additions & 1 deletion okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,20 @@ public ConnectionClientTransport newClientTransport(
if (closed) {
throw new IllegalStateException("The transport factory is closed.");
}
InetSocketAddress proxyAddress = null;
String proxy = System.getenv("GRPC_PROXY_EXP");
if (proxy != null) {
String[] parts = proxy.split(":", 2);
int port = 80;
if (parts.length > 1) {
port = Integer.parseInt(parts[1]);
}
proxyAddress = new InetSocketAddress(parts[0], port);
}
InetSocketAddress inetSocketAddr = (InetSocketAddress) addr;
OkHttpClientTransport transport = new OkHttpClientTransport(inetSocketAddr, authority,
userAgent, executor, socketFactory, Utils.convertSpec(connectionSpec), maxMessageSize);
userAgent, executor, socketFactory, Utils.convertSpec(connectionSpec), maxMessageSize,
proxyAddress, null, null);
if (enableKeepAlive) {
transport.enableKeepAlive(true, keepAliveDelayNanos, keepAliveTimeoutNanos);
}
Expand Down
Loading

0 comments on commit 23f5a6f

Please sign in to comment.