Skip to content

Commit

Permalink
Add TcpUpstreamConnectionHandler class (#760)
Browse files Browse the repository at this point in the history
* Add `TcpUpstreamConnectionHandler` which can be used as standalone or as mixin

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* `TcpUpstreamConnectionHandler` is now an abstract class

* Fix mypy

* `nitpick_ignore` the `proxy.core.base.tcp_upstream.TcpUpstreamConnectionHandler` class

* Add mypy exception for now

* Fix flake

* Fix docstring

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
abhinavsingh and pre-commit-ci[bot] authored Nov 19, 2021
1 parent 44c095a commit 3cfce52
Show file tree
Hide file tree
Showing 11 changed files with 187 additions and 121 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@
(_any_role, 'HttpProtocolHandler'),
(_any_role, 'multiprocessing.Manager'),
(_any_role, 'work_klass'),
(_any_role, 'proxy.core.base.tcp_upstream.TcpUpstreamConnectionHandler'),
(_py_class_role, 'CacheStore'),
(_py_class_role, 'HttpParser'),
(_py_class_role, 'HttpProtocolHandlerPlugin'),
Expand Down
2 changes: 2 additions & 0 deletions proxy/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
"""
from .tcp_server import BaseTcpServerHandler
from .tcp_tunnel import BaseTcpTunnelHandler
from .tcp_upstream import TcpUpstreamConnectionHandler

__all__ = [
'BaseTcpServerHandler',
'BaseTcpTunnelHandler',
'TcpUpstreamConnectionHandler',
]
1 change: 0 additions & 1 deletion proxy/core/base/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class BaseTcpServerHandler(Work):
a. handle_data(data: memoryview) implementation
b. Optionally, also implement other Work method
e.g. initialize, is_inactive, shutdown
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
Expand Down
102 changes: 102 additions & 0 deletions proxy/core/base/tcp_upstream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from abc import ABC, abstractmethod

import ssl
import socket
import logging

from typing import Tuple, List, Optional, Any

from ...common.types import Readables, Writables
from ...core.connection import TcpServerConnection

logger = logging.getLogger(__name__)


class TcpUpstreamConnectionHandler(ABC):
""":class:`~proxy.core.base.TcpUpstreamConnectionHandler` can
be used to insert an upstream server connection lifecycle within
asynchronous proxy.py lifecycle.
Call `initialize_upstream` to initialize the upstream connection object.
Then, directly use ``self.upstream`` object within your class.
.. spelling::
tcp
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
# This is currently a hack, see comments below for rationale,
# will be fixed later.
super().__init__(*args, **kwargs) # type: ignore
self.upstream: Optional[TcpServerConnection] = None
# TODO: Currently, :class:`~proxy.core.base.TcpUpstreamConnectionHandler`
# is used within :class:`~proxy.plugin.ReverseProxyPlugin` and
# :class:`~proxy.plugin.ProxyPoolPlugin`.
#
# For both of which we expect a 4-tuple as arguments
# containing (uuid, flags, client, event_queue).
# We really don't need the rest of the args here.
# May be uuid? May be event_queue in the future.
# But certainly we don't not client here.
# A separate tunnel class must be created which handles
# client connection too.
#
# Both :class:`~proxy.plugin.ReverseProxyPlugin` and
# :class:`~proxy.plugin.ProxyPoolPlugin` are currently
# calling client queue within `handle_upstream_data` callback.
#
# This can be abstracted out too.
self.server_recvbuf_size = args[1].server_recvbuf_size
self.total_size = 0

@abstractmethod
def handle_upstream_data(self, raw: memoryview) -> None:
pass

def initialize_upstream(self, addr: str, port: int) -> None:
self.upstream = TcpServerConnection(addr, port)

def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]:
if not self.upstream:
return [], []
return [self.upstream.connection], [self.upstream.connection] if self.upstream.has_buffer() else []

def read_from_descriptors(self, r: Readables) -> bool:
if self.upstream and self.upstream.connection in r:
try:
raw = self.upstream.recv(self.server_recvbuf_size)
if raw is not None:
self.total_size += len(raw)
self.handle_upstream_data(raw)
else:
return True # Teardown because upstream proxy closed the connection
except ssl.SSLWantReadError:
logger.info('Upstream SSLWantReadError, will retry')
return False
except ConnectionResetError:
logger.debug('Connection reset by upstream')
return True
return False

def write_to_descriptors(self, w: Writables) -> bool:
if self.upstream and self.upstream.connection in w and self.upstream.has_buffer():
try:
self.upstream.flush()
except ssl.SSLWantWriteError:
logger.info('Upstream SSLWantWriteError, will retry')
return False
except BrokenPipeError:
logger.debug('BrokenPipeError when flushing to upstream')
return True
return False
10 changes: 0 additions & 10 deletions proxy/http/server/pac_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from .plugin import HttpWebServerBasePlugin
from .protocols import httpProtocolTypes

from ..websocket import WebsocketFrame
from ..parser import HttpParser

from ...common.utils import bytes_, text_, build_http_response
Expand Down Expand Up @@ -64,15 +63,6 @@ def handle_request(self, request: HttpParser) -> None:
if self.flags.pac_file and self.pac_file_response:
self.client.queue(self.pac_file_response)

def on_websocket_open(self) -> None:
pass # pragma: no cover

def on_websocket_message(self, frame: WebsocketFrame) -> None:
pass # pragma: no cover

def on_client_connection_close(self) -> None:
pass # pragma: no cover

def cache_pac_file_response(self) -> None:
if self.flags.pac_file:
try:
Expand Down
12 changes: 8 additions & 4 deletions proxy/http/server/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,19 @@ def on_client_connection_close(self) -> None:
"""Client has closed the connection, do any clean up task now."""
pass

@abstractmethod
# No longer abstract since v2.4.0
#
# @abstractmethod
def on_websocket_open(self) -> None:
"""Called when websocket handshake has finished."""
raise NotImplementedError() # pragma: no cover
pass # pragma: no cover

@abstractmethod
# No longer abstract since v2.4.0
#
# @abstractmethod
def on_websocket_message(self, frame: WebsocketFrame) -> None:
"""Handle websocket frame."""
raise NotImplementedError() # pragma: no cover
return None # pragma: no cover

# Deprecated since v2.4.0
#
Expand Down
2 changes: 0 additions & 2 deletions proxy/http/server/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,5 @@ def on_client_connection_close(self) -> None:
if not log_handled:
self.access_log(context)

# TODO: Allow plugins to customize access_log, similar
# to how proxy server plugins are able to do it.
def access_log(self, context: Dict[str, Any]) -> None:
logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))
17 changes: 16 additions & 1 deletion proxy/http/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Optional, Tuple

from ..common.constants import COLON, SLASH
from ..common.utils import text_


class Url:
Expand All @@ -36,6 +37,18 @@ def __init__(
self.port: Optional[int] = port
self.remainder: Optional[bytes] = remainder

def __str__(self) -> str:
url = ''
if self.scheme:
url += '{0}://'.format(text_(self.scheme))
if self.hostname:
url += text_(self.hostname)
if self.port:
url += ':{0}'.format(self.port)
if self.remainder:
url += text_(self.remainder)
return url

@classmethod
def from_bytes(cls, raw: bytes) -> 'Url':
"""A URL within proxy.py core can have several styles,
Expand All @@ -57,7 +70,9 @@ def from_bytes(cls, raw: bytes) -> 'Url':
return cls(remainder=raw)
if sraw.startswith('https://') or sraw.startswith('http://'):
is_https = sraw.startswith('https://')
rest = raw[len(b'https://'):] if is_https else raw[len(b'http://'):]
rest = raw[len(b'https://'):] \
if is_https \
else raw[len(b'http://'):]
parts = rest.split(SLASH)
host, port = Url.parse_host_and_port(parts[0])
return cls(
Expand Down
51 changes: 11 additions & 40 deletions proxy/plugin/proxy_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,18 @@
:license: BSD, see LICENSE for more details.
"""
import random
import socket
import logging

from typing import Dict, List, Optional, Any, Tuple
from typing import Dict, List, Optional, Any

from ..common.flag import flags
from ..common.types import Readables, Writables

from ..http import Url, httpMethods
from ..http.parser import HttpParser
from ..http.exception import HttpProtocolException
from ..http.proxy import HttpProxyBasePlugin

from ..core.connection.server import TcpServerConnection
from ..core.base import TcpUpstreamConnectionHandler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,50 +52,21 @@
)


class ProxyPoolPlugin(HttpProxyBasePlugin):
class ProxyPoolPlugin(TcpUpstreamConnectionHandler, HttpProxyBasePlugin):
"""Proxy pool plugin simply acts as a proxy adapter for proxy.py itself.
Imagine this plugin as setting up proxy settings for proxy.py instance itself.
All incoming client requests are proxied to configured upstream proxies."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.upstream: Optional[TcpServerConnection] = None
# Cached attributes to be used during access log override
self.request_host_port_path_method: List[Any] = [
None, None, None, None,
]
self.total_size = 0

def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]:
if not self.upstream:
return [], []
return [self.upstream.connection], [self.upstream.connection] if self.upstream.has_buffer() else []

def read_from_descriptors(self, r: Readables) -> bool:
# Read from upstream proxy and queue for client
if self.upstream and self.upstream.connection in r:
try:
raw = self.upstream.recv(self.flags.server_recvbuf_size)
if raw is not None:
self.total_size += len(raw)
self.client.queue(raw)
else:
return True # Teardown because upstream proxy closed the connection
except ConnectionResetError:
logger.debug('Connection reset by upstream proxy')
return True
return False # Do not teardown connection

def write_to_descriptors(self, w: Writables) -> bool:
# Flush queued data to upstream proxy now
if self.upstream and self.upstream.connection in w and self.upstream.has_buffer():
try:
self.upstream.flush()
except BrokenPipeError:
logger.debug('BrokenPipeError when flushing to upstream proxy')
return True
return False

def handle_upstream_data(self, raw: memoryview) -> None:
self.client.queue(raw)

def before_upstream_connection(
self, request: HttpParser,
Expand All @@ -109,12 +78,14 @@ def before_upstream_connection(
# must be bootstrapped within it's own re-usable and gc'd pool, to avoid establishing
# a fresh upstream proxy connection for each client request.
#
# See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work
# in progress for SSL cache handling.
#
# Implement your own logic here e.g. round-robin, least connection etc.
endpoint = random.choice(self.flags.proxy_pool)[0].split(':')
logger.debug('Using endpoint: {0}:{1}'.format(*endpoint))
self.upstream = TcpServerConnection(
endpoint[0], int(endpoint[1]),
)
self.initialize_upstream(endpoint[0], int(endpoint[1]))
assert self.upstream
try:
self.upstream.connect()
except ConnectionRefusedError:
Expand Down
Loading

0 comments on commit 3cfce52

Please sign in to comment.