Skip to content

Commit

Permalink
fix: serve missing logic from bentoml#3321
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <[email protected]>
  • Loading branch information
aarnphm committed Dec 9, 2022
1 parent d7254c0 commit c760fb1
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 66 deletions.
59 changes: 43 additions & 16 deletions src/bentoml/_internal/server/server.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,58 @@
"""
Server class for getting the Bento client and managing server process
"""

from __future__ import annotations

import logging
import traceback
import subprocess
from typing import TYPE_CHECKING

import attr

from ..utils import cached_property

if TYPE_CHECKING:
from types import TracebackType


logger = logging.getLogger(__name__)

class Server:
def __init__(self, process: subprocess.Popen[bytes], host: str, port: int) -> None:
self._process = process
self._host = host
self._port = port

@attr.frozen
class ServerHandle:
process: subprocess.Popen[bytes]
host: str
port: int
timeout: int = attr.field(default=10)

@cached_property
def client(self):
return self.get_client()

def get_client(self):
from bentoml.client import Client

Client.wait_until_server_is_ready(self._host, self._port, 10)
return Client.from_url(f"http://localhost:{self._port}")
Client.wait_until_server_is_ready(
host=self.host, port=self.port, timeout=self.timeout
)
return Client.from_url(f"http://localhost:{self.port}")

def stop(self) -> None:
self.process.kill()

@property
def process(self) -> subprocess.Popen[bytes]:
return self._process

@property
def address(self) -> str:
return f"{self._host}:{self._port}"
return f"{self.host}:{self.port}"

def __enter__(self):
yield self

def __exit__(
self,
exc_type: type[BaseException],
exc_value: BaseException,
traceback_type: TracebackType,
):
try:
self.stop()
except Exception as e: # pylint: disable=broad-except
logger.error(f"Error stopping server: {e}", exc_info=e)
traceback.print_exception(exc_type, exc_value, traceback_type)
111 changes: 61 additions & 50 deletions src/bentoml/bentos.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from ._internal.tag import Tag
from ._internal.bento import Bento
from ._internal.utils import resolve_user_filepath
from ._internal.server.server import Server
from ._internal.bento.build_config import BentoBuildConfig
from ._internal.configuration.containers import BentoMLContainer

if TYPE_CHECKING:
from ._internal.bento import BentoStore
from ._internal.server.server import ServerHandle

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -427,76 +427,87 @@ def containerize(bento_tag: Tag | str, **kwargs: t.Any) -> bool:
@inject
def serve(
bento: str,
production: bool = False,
port: int = Provide[BentoMLContainer.http.port],
host: str = Provide[BentoMLContainer.http.host],
server_type: str = "http",
api_workers: int | None = Provide[BentoMLContainer.api_server_workers],
backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
reload: bool = False,
production: bool = False,
host: str | None = None,
port: int | None = None,
working_dir: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile: str | None = None,
ssl_ca_certs: str | None = None,
# HTTP-specific args
ssl_keyfile_password: str | None = None,
ssl_version: int | None = None,
ssl_cert_reqs: int | None = None,
ssl_ciphers: str | None = None,
# GRPC-specific args
api_workers: int | None = Provide[BentoMLContainer.api_server_workers],
backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile],
ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile],
ssl_keyfile_password: str | None = Provide[BentoMLContainer.ssl.keyfile_password],
ssl_version: int | None = Provide[BentoMLContainer.ssl.version],
ssl_cert_reqs: int | None = Provide[BentoMLContainer.ssl.cert_reqs],
ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs],
ssl_ciphers: str | None = Provide[BentoMLContainer.ssl.ciphers],
enable_reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled],
enable_channelz: bool = Provide[BentoMLContainer.grpc.channelz.enabled],
max_concurrent_streams: int
| None = Provide[BentoMLContainer.grpc.max_concurrent_streams],
) -> Server:
"""Launch a BentoServer and returns a client that exposes all APIs defined in target service"""
) -> ServerHandle:
from .serve import construct_ssl_args
from ._internal.server.server import ServerHandle

if server_type.lower() not in ["http", "grpc"]:
server_type = server_type.lower()
if server_type not in ["http", "grpc"]:
raise ValueError('Server type must either be "http" or "grpc"')

args = [
if production and reload:
raise ValueError("'reload' and 'production' are mutually exclusive.")

if server_type == "http":
serve_cmd = "serve-http"
if host is None:
host = BentoMLContainer.http.host.get()
if port is None:
port = BentoMLContainer.http.port.get()
else:
serve_cmd = "serve-grpc"
if host is None:
host = BentoMLContainer.grpc.host.get()
if port is None:
port = BentoMLContainer.grpc.port.get()

assert host is not None and port is not None
args: t.List[str] = [
"-m",
"bentoml",
"serve",
serve_cmd,
bento,
"--port",
str(port),
"--host",
host,
"--port",
str(port),
"--backlog",
str(backlog),
*construct_ssl_args(
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
ssl_keyfile_password=ssl_keyfile_password,
ssl_version=ssl_version,
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
),
]
if production:
args.append("--production")
if reload:
args.extend(["--reload", str(reload)])
elif reload:
args.append("--reload")

if api_workers is not None:
args.extend(["--api-workers", str(api_workers)])
if working_dir is not None:
args.extend(["--working-dir", str(working_dir)])
if ssl_certfile is not None:
args.extend(["--ssl-certfile", ssl_certfile])
if ssl_keyfile is not None:
args.extend(["--ssl-keyfile", ssl_keyfile])
if ssl_ca_certs is not None:
args.extend(["--ssl-ca-certs", ssl_ca_certs])
if server_type.lower() == "http":
if ssl_keyfile_password is not None:
args.extend(["--ssl-keyfile-password", ssl_keyfile_password])
if ssl_version is not None:
args.extend(["--ssl-version", str(ssl_version)])
if ssl_cert_reqs is not None:
args.extend(["--ssl-cert-reqs", str(ssl_cert_reqs)])
if ssl_ciphers is not None:
args.extend(["--ssl-ciphers", ssl_ciphers])
if server_type.lower() == "grpc":
if enable_reflection:
args.extend(["--enable-reflection", str(enable_reflection)])
if enable_channelz:
args.extend(["--enable-channelz", str(enable_channelz)])
if max_concurrent_streams is not None:
args.extend(["--max-concurrent-streams", str(max_concurrent_streams)])

process = subprocess.Popen(args, executable=sys.executable)

return Server(process, host, port)
if enable_reflection:
args.append("--enable-reflection")
if enable_channelz:
args.append("--enable-channelz")
if max_concurrent_streams is not None:
args.extend(["--max-concurrent-streams", str(max_concurrent_streams)])

return ServerHandle(
process=subprocess.Popen(args, executable=sys.executable), host=host, port=port
)

0 comments on commit c760fb1

Please sign in to comment.