Merge pull request #23 from eda-labs/urllib3
urllib3 and proxy handling
FloSch62 authored Jan 17, 2025
2 parents 0ada3b3 + 5fa0fbb commit 968db99
Showing 5 changed files with 228 additions and 106 deletions.
76 changes: 43 additions & 33 deletions src/eda.py
@@ -1,9 +1,10 @@
import json
import logging

import requests
import yaml

from src.http_client import create_pool_manager

# configure logging
logger = logging.getLogger(__name__)

@@ -34,21 +35,24 @@ def __init__(self, hostname, username, password, verify):
self.version = None
self.transactions = []

self.http = create_pool_manager(url=self.url, verify=self.verify)

def login(self):
"""
Retrieves an access_token and refresh_token from the EDA API
"""
payload = {"username": self.username, "password": self.password}

response = self.post("auth/login", payload, False).json()
response = self.post("auth/login", payload, False)
response_data = json.loads(response.data.decode("utf-8"))

if "code" in response and response["code"] != 200:
if "code" in response_data and response_data["code"] != 200:
raise Exception(
f"Could not authenticate with EDA, error message: '{response['message']} {response['details']}'"
f"Could not authenticate with EDA, error message: '{response_data['message']} {response_data['details']}'"
)

self.access_token = response["access_token"]
self.refresh_token = response["refresh_token"]
self.access_token = response_data["access_token"]
self.refresh_token = response_data["refresh_token"]

def get_headers(self, requires_auth):
"""
@@ -88,11 +92,7 @@ def get(self, api_path, requires_auth=True):
url = f"{self.url}/{api_path}"
logger.info(f"Performing GET request to '{url}'")

return requests.get(
url,
verify=self.verify,
headers=self.get_headers(requires_auth),
)
return self.http.request("GET", url, headers=self.get_headers(requires_auth))

def post(self, api_path, payload, requires_auth=True):
"""
@@ -110,11 +110,11 @@ def post(self, api_path, payload, requires_auth=True):
"""
url = f"{self.url}/{api_path}"
logger.info(f"Performing POST request to '{url}'")
return requests.post(
return self.http.request(
"POST",
url,
verify=self.verify,
json=payload,
headers=self.get_headers(requires_auth),
body=json.dumps(payload).encode("utf-8"),
)

def is_up(self):
@@ -127,8 +127,9 @@ def is_up(self):
"""
logger.info("Checking whether EDA is up")
health = self.get("core/about/health", requires_auth=False)
logger.debug(health.json())
return health.json()["status"] == "UP"
health_data = json.loads(health.data.decode("utf-8"))
logger.debug(health_data)
return health_data["status"] == "UP"

def get_version(self):
"""
@@ -140,7 +141,9 @@ def get_version(self):
return self.version

logger.info("Getting EDA version")
version = self.get("core/about/version").json()["eda"]["version"].split("-")[0]
version_response = self.get("core/about/version")
version_data = json.loads(version_response.data.decode("utf-8"))
version = version_data["eda"]["version"].split("-")[0]
logger.info(f"EDA version is {version}")

# storing this to make the tool backwards compatible
@@ -254,18 +257,20 @@ def is_transaction_item_valid(self, item):
logger.info("Validating transaction item")

response = self.post("core/transaction/v1/validate", item)
if response.status_code == 204:
if response.status == 204:
logger.info("Validation successful")
return True

response = response.json()
response_data = json.loads(
response.data.decode("utf-8")
) # urllib3 returns raw bytes; decode before parsing JSON

if "code" in response:
message = f"{response['message']}"
if "details" in response:
message = f"{message} - {response['details']}"
if "code" in response_data:
message = f"{response_data['message']}"
if "details" in response_data:
message = f"{message} - {response_data['details']}"
logger.warning(
f"While validating a transaction item, the following validation error was returned (code {response['code']}): '{message}'"
f"While validating a transaction item, the following validation error was returned (code {response_data['code']}): '{message}'"
)

return False
@@ -295,16 +300,21 @@ def commit_transaction(
logger.info(f"Committing transaction with {len(self.transactions)} item(s)")
logger.debug(json.dumps(payload, indent=4))

response = self.post("core/transaction/v1", payload).json()
if "id" not in response:
raise Exception(f"Could not find transaction ID in response {response}")
response = self.post("core/transaction/v1", payload)
response_data = json.loads(response.data.decode("utf-8"))
if "id" not in response_data:
raise Exception(
f"Could not find transaction ID in response {response_data}"
)

transactionId = response["id"]
transactionId = response_data["id"]

logger.info(f"Waiting for transaction with ID {transactionId} to complete")
result = self.get(
f"core/transaction/v1/details/{transactionId}?waitForComplete=true&failOnErrors=true"
).json()
result = json.loads(
self.get(
f"core/transaction/v1/details/{transactionId}?waitForComplete=true&failOnErrors=true"
).data.decode("utf-8")
)

if "code" in result:
message = f"{result['message']}"
@@ -348,7 +358,7 @@ def revert_transaction(self, transactionId):
).json()

response = self.post(f"core/transaction/v1/revert/{transactionId}", {})
result = response.json()
result = json.loads(response.data.decode("utf-8"))

if "code" in result and result["code"] != 0:
message = f"{result['message']}"
@@ -392,7 +402,7 @@ def restore_transaction(self, transactionId):
).json()

response = self.post(f"core/transaction/v1/restore/{restore_point}", {})
result = response.json()
result = json.loads(response.data.decode("utf-8"))

if "code" in result and result["code"] != 0:
message = f"{result['message']}"
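
The changes in src/eda.py swap the requests library for a shared urllib3 pool manager: call sites now pass the HTTP method explicitly, send JSON as an encoded body, and decode response.data themselves instead of calling .json(). A minimal sketch of the new pattern follows; the hostname and credentials are illustrative, not taken from the commit.

import json

from src.http_client import create_pool_manager

# Hypothetical EDA endpoint, used only for illustration
base_url = "https://eda.example.com"
http = create_pool_manager(url=base_url, verify=True)

# urllib3 responses expose the status code as .status and the raw body bytes as .data
response = http.request("GET", f"{base_url}/core/about/health")
health = json.loads(response.data.decode("utf-8"))
print(health["status"])  # "UP" when the API is healthy

# POST bodies are passed as encoded bytes rather than requests' json= keyword
payload = {"username": "admin", "password": "secret"}
response = http.request(
    "POST",
    f"{base_url}/auth/login",
    body=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
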
39 changes: 0 additions & 39 deletions src/helpers.py
@@ -5,7 +5,6 @@
import sys
import tempfile

import requests
from jinja2 import Environment, FileSystemLoader

import src.topology as topology
@@ -98,44 +97,6 @@ def apply_manifest_via_kubectl(yaml_str: str, namespace: str = "eda-system"):
finally:
os.remove(tmp_path)


def get_artifact_from_github(owner: str, repo: str, version: str, asset_filter=None):
"""
Queries GitHub for a specific release artifact.
Parameters
----------
owner: GitHub repository owner
repo: GitHub repository name
version: Version tag to search for (without 'v' prefix)
asset_filter: Optional function(asset_name) -> bool to filter assets
Returns
-------
Tuple of (filename, download_url) or (None, None) if not found
"""
tag = f"v{version}" # Assume GitHub tags are prefixed with 'v'
url = f"https://api.github.com/repos/{owner}/{repo}/releases/tags/{tag}"

logger.info(f"Querying GitHub release {tag} from {owner}/{repo}")
resp = requests.get(url)

if resp.status_code != 200:
logger.warning(f"Failed to fetch release for {tag}, status={resp.status_code}")
return None, None

data = resp.json()
assets = data.get("assets", [])

for asset in assets:
name = asset.get("name", "")
if asset_filter is None or asset_filter(name):
return name, asset.get("browser_download_url")

# No matching asset found
return None, None


def normalize_name(name: str) -> str:
"""
Returns a Kubernetes-compliant name by:
139 changes: 139 additions & 0 deletions src/http_client.py
@@ -0,0 +1,139 @@
import logging
import os
import re
import urllib3
from urllib.parse import urlparse

logger = logging.getLogger(__name__)


def get_proxy_settings():
"""
Get proxy settings from environment variables.
Handles both upper and lowercase variants.
Returns
-------
tuple: (http_proxy, https_proxy, no_proxy)
"""
# Check both variants
http_upper = os.environ.get("HTTP_PROXY")
http_lower = os.environ.get("http_proxy")
https_upper = os.environ.get("HTTPS_PROXY")
https_lower = os.environ.get("https_proxy")
no_upper = os.environ.get("NO_PROXY")
no_lower = os.environ.get("no_proxy")

# Log if both variants are set
if http_upper and http_lower and http_upper != http_lower:
logger.warning(
f"Both HTTP_PROXY ({http_upper}) and http_proxy ({http_lower}) are set with different values. Using HTTP_PROXY."
)

if https_upper and https_lower and https_upper != https_lower:
logger.warning(
f"Both HTTPS_PROXY ({https_upper}) and https_proxy ({https_lower}) are set with different values. Using HTTPS_PROXY."
)

if no_upper and no_lower and no_upper != no_lower:
logger.warning(
f"Both NO_PROXY ({no_upper}) and no_proxy ({no_lower}) are set with different values. Using NO_PROXY."
)

# Use uppercase variants if set, otherwise lowercase
http_proxy = http_upper if http_upper is not None else http_lower
https_proxy = https_upper if https_upper is not None else https_lower
no_proxy = no_upper if no_upper is not None else no_lower or ""

return http_proxy, https_proxy, no_proxy


def should_bypass_proxy(url, no_proxy=None):
"""
Check if the given URL should bypass proxy based on NO_PROXY settings.
Parameters
----------
url : str
The URL to check
no_proxy : str, optional
The NO_PROXY string to use. If None, gets from environment.
Returns
-------
bool
True if proxy should be bypassed, False otherwise
"""
if no_proxy is None:
_, _, no_proxy = get_proxy_settings()

if not no_proxy:
return False

parsed_url = urlparse(url if "//" in url else f"http://{url}")
hostname = parsed_url.hostname

if not hostname:
return False

# Split NO_PROXY into parts and clean them
no_proxy_parts = [p.strip() for p in no_proxy.split(",") if p.strip()]

for no_proxy_value in no_proxy_parts:
# Convert .foo.com to foo.com
if no_proxy_value.startswith("."):
no_proxy_value = no_proxy_value[1:]

# Handle IP addresses and CIDR notation
if re.match(r"^(?:\d{1,3}\.){3}\d{1,3}(?:/\d{1,2})?$", no_proxy_value):
# TODO: Implement CIDR matching if needed
if hostname == no_proxy_value:
return True
# Handle domain names with wildcards
else:
pattern = re.escape(no_proxy_value).replace(r"\*", ".*")
if re.match(f"^{pattern}$", hostname, re.IGNORECASE):
return True

return False


def create_pool_manager(url=None, verify=True):
"""
Create a PoolManager or ProxyManager based on environment settings and URL
Parameters
----------
url : str, optional
The URL that will be accessed with this pool manager
If provided, NO_PROXY rules will be checked
verify : bool
Whether to verify SSL certificates
Returns
-------
urllib3.PoolManager or urllib3.ProxyManager
"""
http_proxy, https_proxy, no_proxy = get_proxy_settings()

# Check if this URL should bypass proxy
if url and should_bypass_proxy(url, no_proxy):
logger.debug(f"URL {url} matches NO_PROXY rules, creating direct PoolManager")
return urllib3.PoolManager(
cert_reqs="CERT_REQUIRED" if verify else "CERT_NONE",
retries=urllib3.Retry(3),
)

proxy_url = https_proxy or http_proxy
if proxy_url:
logger.debug(f"Creating ProxyManager with proxy URL: {proxy_url}")
return urllib3.ProxyManager(
proxy_url,
cert_reqs="CERT_REQUIRED" if verify else "CERT_NONE",
retries=urllib3.Retry(3),
)

logger.debug("Creating PoolManager without proxy")
return urllib3.PoolManager(
cert_reqs="CERT_REQUIRED" if verify else "CERT_NONE", retries=urllib3.Retry(3)
)
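
The new src/http_client.py centralizes proxy handling: create_pool_manager returns a plain PoolManager when the target URL matches a NO_PROXY rule (or no proxy is configured) and a ProxyManager otherwise. A short usage sketch under assumed environment variables; all values are illustrative and not part of the commit.

import os

from src.http_client import create_pool_manager, should_bypass_proxy

# Illustrative proxy environment, set here only for the example
os.environ["HTTPS_PROXY"] = "http://proxy.example.com:3128"
os.environ["NO_PROXY"] = "localhost,*.internal.example.com"

# The wildcard NO_PROXY entry matches this hostname, so the proxy is bypassed
print(should_bypass_proxy("https://eda.internal.example.com"))  # True

# No NO_PROXY match here, so a urllib3.ProxyManager routed via HTTPS_PROXY is returned
http = create_pool_manager(url="https://eda.example.com", verify=True)
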
