Skip to content

Commit

Permalink
fix: get channel target for a gRPC request (#1339)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohmayr authored Jan 29, 2025
1 parent 59965a4 commit 16ea766
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
24 changes: 21 additions & 3 deletions tests/unit/pubsub_v1/publisher/test_publisher_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@
typed_flaky = cast(Callable[[C], C], flaky(max_runs=5, min_passes=1))


# NOTE: This interceptor is required to create an intercept channel.
class _PublisherClientGrpcInterceptor(
grpc.UnaryUnaryClientInterceptor,
):
def intercept_unary_unary(self, continuation, client_call_details, request):
pass


def _assert_retries_equal(retry, retry2):
# Retry instances cannot be directly compared, because their predicates are
# different instances of the same function. We thus manually compare their other
Expand Down Expand Up @@ -416,17 +424,27 @@ def init(self, *args, **kwargs):
assert client.transport._ssl_channel_credentials == mock_ssl_creds


def test_init_emulator(monkeypatch):
def test_init_emulator(monkeypatch, creds):
monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/foo/bar:123")
# NOTE: When the emulator host is set, a custom channel will be used, so
# no credentials (mock ot otherwise) can be passed in.
client = publisher.Client()

# TODO(https://github.com/grpc/grpc/issues/38519): Workaround to create an intercept
# channel (for forwards compatibility) with a channel created by the publisher client
# where target is set to the emulator host.
channel = publisher.Client().transport.grpc_channel
interceptor = _PublisherClientGrpcInterceptor()
intercept_channel = grpc.intercept_channel(channel, interceptor)
transport = publisher.Client.get_transport_class("grpc")(
credentials=creds, channel=intercept_channel
)
client = publisher.Client(transport=transport)

# Establish that a gRPC request would attempt to hit the emulator host.
#
# Sadly, there seems to be no good way to do this without poking at
# the private API of gRPC.
channel = client._transport.publish._channel
channel = client._transport.publish._thunk("")._channel
# Behavior to include dns prefix changed in gRPCv1.63
grpc_major, grpc_minor = [int(part) for part in grpc.__version__.split(".")[0:2]]
if grpc_major > 1 or (grpc_major == 1 and grpc_minor >= 63):
Expand Down
24 changes: 21 additions & 3 deletions tests/unit/pubsub_v1/subscriber/test_subscriber_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
from google.pubsub_v1.types import PubsubMessage


# NOTE: This interceptor is required to create an intercept channel.
class _SubscriberClientGrpcInterceptor(
grpc.UnaryUnaryClientInterceptor,
):
def intercept_unary_unary(self, continuation, client_call_details, request):
pass


def test_init_default_client_info(creds):
client = subscriber.Client(credentials=creds)

Expand Down Expand Up @@ -119,17 +127,27 @@ def init(self, *args, **kwargs):
assert client.transport._ssl_channel_credentials == mock_ssl_creds


def test_init_emulator(monkeypatch):
def test_init_emulator(monkeypatch, creds):
monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/baz/bacon:123")
# NOTE: When the emulator host is set, a custom channel will be used, so
# no credentials (mock ot otherwise) can be passed in.
client = subscriber.Client()

# TODO(https://github.com/grpc/grpc/issues/38519): Workaround to create an intercept
# channel (for forwards compatibility) with a channel created by the publisher client
# where target is set to the emulator host.
channel = subscriber.Client().transport.grpc_channel
interceptor = _SubscriberClientGrpcInterceptor()
intercept_channel = grpc.intercept_channel(channel, interceptor)
transport = subscriber.Client.get_transport_class("grpc")(
credentials=creds, channel=intercept_channel
)
client = subscriber.Client(transport=transport)

# Establish that a gRPC request would attempt to hit the emulator host.
#
# Sadly, there seems to be no good way to do this without poking at
# the private API of gRPC.
channel = client._transport.pull._channel
channel = client._transport.pull._thunk("")._channel
# Behavior to include dns prefix changed in gRPCv1.63
grpc_major, grpc_minor = [int(part) for part in grpc.__version__.split(".")[0:2]]
if grpc_major > 1 or (grpc_major == 1 and grpc_minor >= 63):
Expand Down

0 comments on commit 16ea766

Please sign in to comment.