diff --git a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java index 5f4b4431851..35e19503618 100644 --- a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java @@ -297,10 +297,6 @@ public void setEnableTracing(boolean enabled) { @Override public ManagedChannel build() { ClientTransportFactory transportFactory = buildTransportFactory(); - if (authorityOverride != null) { - transportFactory = new AuthorityOverridingTransportFactory( - transportFactory, authorityOverride); - } NameResolver.Factory nameResolverFactory = this.nameResolverFactory; if (nameResolverFactory == null) { // Avoid loading the provider unless necessary, as a way to workaround a possibly-costly @@ -308,6 +304,10 @@ public ManagedChannel build() { // getResource(), then this shouldn't be a problem unless called on the UI thread. nameResolverFactory = NameResolverProvider.asFactory(); } + if (authorityOverride != null) { + nameResolverFactory = + new OverrideAuthorityNameResolverFactory(nameResolverFactory, authorityOverride); + } List effectiveInterceptors = new ArrayList(this.interceptors); @@ -382,29 +382,6 @@ public Executor returnObject(Object returned) { } } - private static class AuthorityOverridingTransportFactory implements ClientTransportFactory { - final ClientTransportFactory factory; - final String authorityOverride; - - AuthorityOverridingTransportFactory( - ClientTransportFactory factory, String authorityOverride) { - this.factory = Preconditions.checkNotNull(factory, "factory should not be null"); - this.authorityOverride = Preconditions.checkNotNull( - authorityOverride, "authorityOverride should not be null"); - } - - @Override - public ConnectionClientTransport newClientTransport(SocketAddress serverAddress, - String authority, @Nullable String userAgent) { - return factory.newClientTransport(serverAddress, authorityOverride, userAgent); - } - - @Override - public void close() { - factory.close(); - } - } - private static class DirectAddressNameResolverFactory extends NameResolver.Factory { final SocketAddress address; final String authority; @@ -440,6 +417,60 @@ public String getDefaultScheme() { } } + /** + * A wrapper class that overrides the authority of a NameResolver, while preserving all other + * functionality. + */ + @VisibleForTesting + static class OverrideAuthorityNameResolverFactory extends NameResolver.Factory { + private final NameResolver.Factory delegate; + private final String authorityOverride; + + /** + * Constructor for the {@link NameResolver.Factory} + * + * @param delegate The actual underlying factory that will produce the a {@link NameResolver} + * @param authorityOverride The authority that will be returned by {@link + * NameResolver#getServiceAuthority()} + */ + OverrideAuthorityNameResolverFactory(NameResolver.Factory delegate, + String authorityOverride) { + this.delegate = delegate; + this.authorityOverride = authorityOverride; + } + + @Nullable + @Override + public NameResolver newNameResolver(URI targetUri, Attributes params) { + final NameResolver resolver = delegate.newNameResolver(targetUri, params); + // Do not wrap null values. We do not want to impede error signaling. + if (resolver == null) { + return null; + } + return new NameResolver() { + @Override + public String getServiceAuthority() { + return authorityOverride; + } + + @Override + public void start(Listener listener) { + resolver.start(listener); + } + + @Override + public void shutdown() { + resolver.shutdown(); + } + }; + } + + @Override + public String getDefaultScheme() { + return delegate.getDefaultScheme(); + } + } + /** * Returns the correctly typed version of the builder. */ diff --git a/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java index 3413e778155..c4df0bf5206 100644 --- a/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java @@ -32,8 +32,14 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import io.grpc.Attributes; +import io.grpc.NameResolver; import java.net.InetSocketAddress; import java.net.URI; import java.util.concurrent.TimeUnit; @@ -98,4 +104,31 @@ public Builder usePlaintext(boolean value) { builder.idleTimeout(30, TimeUnit.SECONDS); assertEquals(TimeUnit.SECONDS.toMillis(30), builder.getIdleTimeoutMillis()); } + + @Test + public void overrideAuthorityNameResolverWrapsDelegateTest() { + NameResolver nameResolverMock = mock(NameResolver.class); + NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class); + when(wrappedFactory.newNameResolver(any(URI.class), any(Attributes.class))) + .thenReturn(nameResolverMock); + String override = "override:5678"; + NameResolver.Factory factory = + new AbstractManagedChannelImplBuilder.OverrideAuthorityNameResolverFactory(wrappedFactory, + override); + NameResolver nameResolver = factory.newNameResolver(URI.create("dns:///localhost:443"), + Attributes.EMPTY); + assertNotNull(nameResolver); + assertEquals(override, nameResolver.getServiceAuthority()); + } + + @Test + public void overrideAuthorityNameResolverWontWrapNullTest() { + NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class); + when(wrappedFactory.newNameResolver(any(URI.class), any(Attributes.class))).thenReturn(null); + NameResolver.Factory factory = + new AbstractManagedChannelImplBuilder.OverrideAuthorityNameResolverFactory(wrappedFactory, + "override:5678"); + assertEquals(null, + factory.newNameResolver(URI.create("dns:///localhost:443"), Attributes.EMPTY)); + } } diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java index 8109c56921f..9f03113dbe5 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java @@ -35,6 +35,7 @@ import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; +import io.grpc.ManagedChannel; import io.grpc.netty.InternalNettyChannelBuilder.OverrideAuthorityChecker; import io.grpc.netty.ProtocolNegotiators.TlsNegotiator; import io.netty.handler.ssl.SslContext; @@ -54,6 +55,37 @@ public class NettyChannelBuilderTest { @Rule public final ExpectedException thrown = ExpectedException.none(); private final SslContext noSslContext = null; + @Test + public void authorityIsReadable() { + NettyChannelBuilder builder = NettyChannelBuilder.forAddress("original", 1234); + assertEquals("original:1234", builder.build().authority()); + } + + @Test + public void overrideAuthorityIsReadableForAddress() { + NettyChannelBuilder builder = NettyChannelBuilder.forAddress("original", 1234); + overrideAuthorityIsReadableHelper(builder, "override:5678"); + } + + @Test + public void overrideAuthorityIsReadableForTarget() { + NettyChannelBuilder builder = NettyChannelBuilder.forTarget("original:1234"); + overrideAuthorityIsReadableHelper(builder, "override:5678"); + } + + @Test + public void overrideAuthorityIsReadableForSocketAddress() { + NettyChannelBuilder builder = NettyChannelBuilder.forAddress(new SocketAddress(){}); + overrideAuthorityIsReadableHelper(builder, "override:5678"); + } + + private void overrideAuthorityIsReadableHelper(NettyChannelBuilder builder, + String overrideAuthority) { + builder.overrideAuthority(overrideAuthority); + ManagedChannel channel = builder.build(); + assertEquals(overrideAuthority, channel.authority()); + } + @Test public void overrideAllowsInvalidAuthority() { NettyChannelBuilder builder = new NettyChannelBuilder(new SocketAddress(){}); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java index 011f5071e99..f1d2ff73439 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java @@ -53,6 +53,30 @@ public class OkHttpChannelBuilderTest { @Rule public final ExpectedException thrown = ExpectedException.none(); + @Test + public void authorityIsReadable() { + OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("original", 1234); + assertEquals("original:1234", builder.build().authority()); + } + + @Test + public void overrideAuthorityIsReadableForAddress() { + OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("original", 1234); + overrideAuthorityIsReadableHelper(builder, "override:5678"); + } + + @Test + public void overrideAuthorityIsReadableForTarget() { + OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("original:1234"); + overrideAuthorityIsReadableHelper(builder, "override:5678"); + } + + private void overrideAuthorityIsReadableHelper(OkHttpChannelBuilder builder, + String overrideAuthority) { + builder.overrideAuthority(overrideAuthority); + assertEquals(overrideAuthority, builder.build().authority()); + } + @Test public void overrideAllowsInvalidAuthority() { OkHttpChannelBuilder builder = new OkHttpChannelBuilder("good", 1234) {