Skip to content

Commit

Permalink
feat: Refactor OAuth2 client credentials channel creation
Browse files Browse the repository at this point in the history
This commit refactors the `create_oauth2_client_credentials_channel` function in the `pyzeebe.channel.oauth_channel` module. The function is responsible for creating a channel connected to a Camunda Cloud cluster using OAuth2 client credentials for authentication. It takes various parameters such as the target address, client ID, client secret, authorization server, scope, audience, and expiration time. The function returns a GRPC channel connected to the Zeebe Gateway.

The refactoring improves the code structure and readability of the function, making it easier to maintain and understand. This change is necessary to provide support for authenticating with Camunda Cloud using OAuth2 client credentials.

Co-authored-by: dependabot[bot] <[email protected]>
  • Loading branch information
felicijus and dependabot[bot] committed Aug 21, 2024
1 parent 01cbd5c commit 07140cf
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 87 deletions.
13 changes: 8 additions & 5 deletions pyzeebe/channel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# from pyzeebe.channel.camunda_cloud_channel import create_camunda_cloud_channel # FIXME: could be removed
from pyzeebe.channel.insecure_channel import create_insecure_channel
from pyzeebe.channel.oauth_channel import (
create_camunda_cloud_channel,
create_oauth2_client_credentials_channel,
from pyzeebe.channel.camunda_cloud_channel import (
create_camunda_cloud_channel, # FIXME: could be removed
)
from pyzeebe.channel.insecure_channel import create_insecure_channel

# from pyzeebe.channel.oauth_channel import (
# create_camunda_cloud_channel,
# create_oauth2_client_credentials_channel,
# )
from pyzeebe.channel.secure_channel import create_secure_channel
56 changes: 40 additions & 16 deletions pyzeebe/channel/oauth_channel.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from functools import partial
from typing import Callable, Optional
from typing import Optional

import grpc
from grpc.aio._typing import ChannelArgumentType
from oauthlib import oauth2
from requests_oauthlib import OAuth2Session

from pyzeebe.credentials.oauth import (
Oauth2ClientCredentialsMetadataPlugin,
OAuth2MetadataPlugin,
)
from pyzeebe.credentials.oauth import Oauth2ClientCredentialsMetadataPlugin


def create_oauth2_client_credentials_channel(
Expand All @@ -19,8 +14,10 @@ def create_oauth2_client_credentials_channel(
authorization_server: str,
scope: Optional[str] = None,
audience: Optional[str] = None,
expire_in: Optional[int] = None,
channel_credentials: grpc.ChannelCredentials = grpc.ssl_channel_credentials(),
channel_options: Optional[ChannelArgumentType] = None,
leeway: int = 60,
expire_in: Optional[int] = None,
) -> grpc.aio.Channel:
"""
Create a gRPC channel for connecting to Camunda 8 (Self-Managed) with OAuth2ClientCredentials.
Expand All @@ -38,12 +35,14 @@ def create_oauth2_client_credentials_channel(
scope (Optional[str]): The scope of the access request. Defaults to None.
audience (Optional[str]): The audience for authentication. Defaults to None.
expire_in (Optional[int]): The number of seconds the token is valid for. Defaults to None.
Should only be used if the token does not contain an "expires_in" attribute.
channel_credentials (grpc.ChannelCredentials): The gRPC channel credentials. Defaults to grpc.ssl_channel_credentials().
channel_options (Optional[ChannelArgumentType], optional): Additional options for the gRPC channel. Defaults to None.
See https://grpc.github.io/grpc/python/glossary.html#term-channel_arguments
leeway (int): The number of seconds to consider the token as expired before the actual expiration time. Defaults to 60.
expire_in (Optional[int]): The number of seconds the token is valid for. Defaults to None.
Should only be used if the token does not contain an "expires_in" attribute.
Returns:
grpc.aio.Channel: A gRPC channel connected to the Zeebe Gateway.
Expand All @@ -57,10 +56,20 @@ def create_oauth2_client_credentials_channel(
authorization_server=authorization_server,
scope=scope,
audience=audience,
leeway=leeway,
expire_in=expire_in,
)

channel = oauth2_client_credentials.channel(target=target, channel_options=channel_options)
call_credentials: grpc.CallCredentials = grpc.metadata_call_credentials(oauth2_client_credentials)
# channel_credentials: grpc.ChannelCredentials = channel_credentials or grpc.ssl_channel_credentials()
composite_credentials: grpc.ChannelCredentials = grpc.composite_channel_credentials(
channel_credentials, call_credentials
)

channel: grpc.aio.Channel = grpc.aio.secure_channel(
target=target, credentials=composite_credentials, options=channel_options
)

return channel


Expand All @@ -72,8 +81,10 @@ def create_camunda_cloud_channel(
scope: str = "Zeebe",
authorization_server: str = "https://login.cloud.camunda.io/oauth/token",
audience: str = "zeebe.camunda.io",
expire_in: Optional[int] = None,
channel_credentials: grpc.ChannelCredentials = grpc.ssl_channel_credentials(),
channel_options: Optional[ChannelArgumentType] = None,
leeway: int = 60,
expire_in: Optional[int] = None,
) -> grpc.aio.Channel:
"""
Create a gRPC channel for connecting to Camunda 8 Cloud (SaaS).
Expand All @@ -88,11 +99,14 @@ def create_camunda_cloud_channel(
to the client after successfully authenticating the client. Defaults to "https://login.cloud.camunda.io/oauth/token".
audience (Optional[str]): The audience for authentication. Defaults to "zeebe.camunda.io".
expire_in (Optional[int]): The expiration time for the token. Defaults to None.
channel_credentials (grpc.ChannelCredentials): The gRPC channel credentials. Defaults to grpc.ssl_channel_credentials().
channel_options (Optional[ChannelArgumentType], optional): Additional options for the gRPC channel. Defaults to None.
See https://grpc.github.io/grpc/python/glossary.html#term-channel_arguments
leeway (int): The number of seconds to consider the token as expired before the actual expiration time. Defaults to 60.
expire_in (Optional[int]): The number of seconds the token is valid for. Defaults to None.
Should only be used if the token does not contain an "expires_in" attribute.
Returns:
grpc.aio.Channel: The gRPC channel for connecting to Camunda Cloud.
"""
Expand All @@ -105,6 +119,7 @@ def create_camunda_cloud_channel(
authorization_server=authorization_server,
scope=scope,
audience=audience,
leeway=leeway,
expire_in=expire_in,
)

Expand All @@ -118,5 +133,14 @@ def create_camunda_cloud_channel(
)
oauth2_client_credentials._func_retrieve_token = func

channel = oauth2_client_credentials.channel(target=target, channel_options=channel_options)
call_credentials: grpc.CallCredentials = grpc.metadata_call_credentials(oauth2_client_credentials)
# channel_credentials: grpc.ChannelCredentials = channel_credentials or grpc.ssl_channel_credentials()
composite_credentials: grpc.ChannelCredentials = grpc.composite_channel_credentials(
channel_credentials, call_credentials
)

channel: grpc.aio.Channel = grpc.aio.secure_channel(
target=target, credentials=composite_credentials, options=channel_options
)

return channel
64 changes: 20 additions & 44 deletions pyzeebe/credentials/oauth.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import json
import logging
import math
import time
import timeit
from collections.abc import Callable
from functools import partial
from typing import Optional
from typing import Any, Optional

import grpc
import requests
from grpc._auth import _sign_request
from grpc.aio._typing import ChannelArgumentType
from oauthlib import oauth2
from requests_oauthlib import OAuth2Session

Expand All @@ -29,7 +28,8 @@ class OAuth2MetadataPlugin(grpc.AuthMetadataPlugin):
def __init__(
self,
oauth2session: OAuth2Session,
func_retrieve_token: Callable,
func_retrieve_token: Callable[[], Any],
leeway: int = 60,
expire_in: Optional[int] = None,
) -> None:
"""AuthMetadataPlugin for OAuth2 Authentication.
Expand All @@ -40,11 +40,14 @@ def __init__(
oauth2session (OAuth2Session): The OAuth2Session object.
func_fetch_token (Callable): The function to fetch the token.
expire_in (Optional[int]): Only used if the token does not contain an "expires_in" attribute. The number of seconds the token is valid for.
leeway (int): The number of seconds to consider the token as expired before the actual expiration time. Defaults to 60.
expire_in (Optional[int]): The number of seconds the token is valid for. Defaults to None.
Should only be used if the token does not contain an "expires_in" attribute.
"""
self._oauth: OAuth2Session = oauth2session
self._func_retrieve_token: Callable = func_retrieve_token
self._func_retrieve_token: Callable[[], Any] = func_retrieve_token

self._leeway: int = leeway
self._expires_in: Optional[int] = expire_in
if self._expires_in is not None:
# NOTE: "expires_in" is only RECOMMENDED
Expand All @@ -55,7 +58,7 @@ def __call__(
self,
context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
):
) -> None:
start_time = timeit.default_timer()

try:
Expand Down Expand Up @@ -86,7 +89,7 @@ def is_token_expired(self) -> bool:
# https://datatracker.ietf.org/doc/html/rfc6749#appendix-A
# https://oauthlib.readthedocs.io/en/latest/_modules/oauthlib/oauth2/rfc6749/clients/base.html?highlight=expires_at#
expires_at = self._oauth.token.get("expires_at", 0)
if time.time() > math.floor(expires_at): # FIXME: math.floor, token could be already expired
if time.time() > (expires_at - self._leeway):
return True

return False
Expand All @@ -101,7 +104,7 @@ def retrieve_token(self) -> None:
logger.error(str(e))
raise e

def _no_expiration(self, r):
def _no_expiration(self, r: requests.Response) -> requests.Response:
"""
Sets the expiration time for the token if it is not provided in the response.
Expand Down Expand Up @@ -135,6 +138,7 @@ def __init__(
authorization_server: str,
scope: Optional[str] = None,
audience: Optional[str] = None,
leeway: int = 60,
expire_in: Optional[int] = None,
):
"""AuthMetadataPlugin for OAuth2 Client Credentials Authentication based on Oauth2MetadataPlugin.
Expand All @@ -150,16 +154,18 @@ def __init__(
scope (Optional[str]): The scope of the access request. Defaults to None.
audience (Optional[str]): The audience for authentication. Defaults to None.
leeway (int): The number of seconds to consider the token as expired before the actual expiration time. Defaults to 60.
expire_in (Optional[int]): The number of seconds the token is valid for. Defaults to None.
Should only be used if the token does not contain an "expires_in" attribute.
"""

self.client_id: str = client_id
self.client_secret: str = client_secret
self.authorization_server: str = authorization_server
self.scope: str | None = scope
self.audience: str | None = audience
self.expire_in: int | None = expire_in
self.scope: Optional[str] = scope
self.audience: Optional[str] = audience
self.leeway: int = leeway
self.expire_in: Optional[int] = expire_in

client = oauth2.BackendApplicationClient(client_id=self.client_id, scope=self.scope)
oauth2session = OAuth2Session(client=client)
Expand All @@ -171,36 +177,6 @@ def __init__(
audience=self.audience,
)

super().__init__(oauth2session=oauth2session, func_retrieve_token=func, expire_in=self.expire_in)

def call_credentials(self) -> grpc.CallCredentials:
"""
Creates and returns the gRPC call credentials.
Returns:
grpc.CallCredentials: The gRPC call credentials.
"""
return grpc.metadata_call_credentials(self)

def channel(self, target: str, channel_options: Optional[ChannelArgumentType] = None) -> grpc.aio.Channel:
"""
Creates and returns a gRPC channel with secure credentials.
Args:
target (str): The target address of the channel.
channel_options (Optional[ChannelArgumentType], optional): Additional options for the gRPC channel. Defaults to None.
See https://grpc.github.io/grpc/python/glossary.html#term-channel_arguments
Returns:
grpc.aio.Channel: The created gRPC channel.
"""

call_credentials: grpc.CallCredentials = self.call_credentials()
channel_credentials: grpc.ChannelCredentials = grpc.ssl_channel_credentials()

composite_credentials: grpc.ChannelCredentials = grpc.composite_channel_credentials(
channel_credentials, call_credentials
)

channel: grpc.aio.Channel = grpc.aio.secure_channel(
target=target, credentials=composite_credentials, options=channel_options
super().__init__(
oauth2session=oauth2session, func_retrieve_token=func, leeway=self.leeway, expire_in=self.expire_in
)
return channel
75 changes: 53 additions & 22 deletions tests/unit/credentials/oauth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,26 +105,27 @@ def test_initialization(self, oauth2mp: OAuth2MetadataPlugin):

assert oauth2mp._oauth.client_id == "test_id"
assert oauth2mp._oauth.scope == "test_scope_a test_scope_b"
assert oauth2mp._leeway == 60

assert oauth2mp._func_retrieve_token.keywords["token_url"] == "https://auth.server"
assert oauth2mp._func_retrieve_token.keywords["client_secret"] == "test_secret"
assert oauth2mp._func_retrieve_token.keywords["audience"] == "test_audience"

@pytest.mark.parametrize(
"authorized, token, expected",
"authorized, token, is_expired",
[
(True, {"expires_at": time.time() + 3600}, False),
(True, {"expires_at": time.time() - 3600}, True),
(False, {"expires_at": time.time() + 3600}, True),
(False, {"expires_at": time.time() - 3600}, True),
(True, {"expires_at": time.time() + 300}, False),
(True, {"expires_at": time.time()}, True),
(False, {"expires_at": time.time() + 300}, True),
(False, {"expires_at": time.time()}, True),
],
)
@mock.patch("pyzeebe.credentials.oauth.OAuth2Session.authorized", new_callable=PropertyMock)
def test_is_token_expired(self, mock_authorized, authorized, token, expected, oauth2mp):
def test_is_token_expired(self, mock_authorized, authorized, token, is_expired, oauth2mp):
mock_authorized.return_value = authorized

oauth2mp._oauth.token = token
assert oauth2mp.is_token_expired() is expected
assert oauth2mp.is_token_expired() is is_expired

@pytest.fixture()
def mock_response(self, token):
Expand Down Expand Up @@ -290,6 +291,51 @@ def test_no_expiration_with_token(
mock_dumps.assert_called_once_with(token)


class TestOAuth2MetadataPluginLeeway:

# NOTE: Following tests in scenario where "leeway" needs to be considered
@pytest.fixture(autouse=True)
def oauth2mp_leeway(self, oauth2session, func):

return OAuth2MetadataPlugin(
oauth2session=oauth2session,
func_retrieve_token=func,
leeway=30,
)

def test_initialization(self, oauth2mp_leeway: OAuth2MetadataPlugin):

assert isinstance(oauth2mp_leeway._oauth, OAuth2Session)
assert isinstance(oauth2mp_leeway._func_retrieve_token, partial)
assert isinstance(oauth2mp_leeway._func_retrieve_token, Callable)

assert oauth2mp_leeway._oauth.client_id == "test_id"
assert oauth2mp_leeway._oauth.scope == "test_scope_a test_scope_b"
assert oauth2mp_leeway._leeway == 30

assert oauth2mp_leeway._func_retrieve_token.keywords["token_url"] == "https://auth.server"
assert oauth2mp_leeway._func_retrieve_token.keywords["client_secret"] == "test_secret"
assert oauth2mp_leeway._func_retrieve_token.keywords["audience"] == "test_audience"

@pytest.mark.parametrize(
"authorized, token, is_expired",
[
(True, {"expires_at": time.time() + 300}, False),
(True, {"expires_at": time.time() + 30}, True), # NOTE: leeway is 30
(True, {"expires_at": time.time()}, True),
(False, {"expires_at": time.time() + 300}, True),
(False, {"expires_at": time.time() + 30}, True), # NOTE: leeway is 30
(False, {"expires_at": time.time()}, True),
],
)
@mock.patch("pyzeebe.credentials.oauth.OAuth2Session.authorized", new_callable=PropertyMock)
def test_is_token_expired(self, mock_authorized, authorized, token, is_expired, oauth2mp_leeway):
mock_authorized.return_value = authorized

oauth2mp_leeway._oauth.token = token
assert oauth2mp_leeway.is_token_expired() is is_expired


class TestOauth2ClientCredentialsMetadataPlugin:

@pytest.fixture(autouse=True)
Expand All @@ -314,18 +360,3 @@ def test_initialization(self, oauth2ccmp: Oauth2ClientCredentialsMetadataPlugin)
assert oauth2ccmp._func_retrieve_token.keywords["token_url"] == "https://auth.server"
assert oauth2ccmp._func_retrieve_token.keywords["client_secret"] == "test_secret"
assert oauth2ccmp._func_retrieve_token.keywords["audience"] == "test_audience"

def test_call_credentials(self, oauth2ccmp: Oauth2ClientCredentialsMetadataPlugin):
assert isinstance(oauth2ccmp.call_credentials(), grpc.CallCredentials)

def test_channel(self, oauth2ccmp: Oauth2ClientCredentialsMetadataPlugin):

channel = oauth2ccmp.channel(target="zeebe-grpc-camunda.test.de:443")
assert isinstance(channel, grpc.aio.Channel)

def test_channel_options(self, oauth2ccmp: Oauth2ClientCredentialsMetadataPlugin):

channel_options: ChannelArgumentType = (("grpc.ssl_target_name_override", "test_override"),)
channel = oauth2ccmp.channel(target="zeebe-grpc-camunda.test.de:443", channel_options=channel_options)

assert isinstance(channel, grpc.aio.Channel)

0 comments on commit 07140cf

Please sign in to comment.