Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: serve missing logic from #3321 #3336

Merged
merged 5 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions src/bentoml/_internal/server/server.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,58 @@
"""
Server class for getting the Bento client and managing server process
"""

from __future__ import annotations

import logging
import traceback
import subprocess
from typing import TYPE_CHECKING

import attr

from ..utils import cached_property

if TYPE_CHECKING:
from types import TracebackType


logger = logging.getLogger(__name__)

class Server:
def __init__(self, process: subprocess.Popen[bytes], host: str, port: int) -> None:
self._process = process
self._host = host
self._port = port

@attr.frozen
class ServerHandle:
process: subprocess.Popen[bytes]
host: str
port: int
timeout: int = attr.field(default=10)

@cached_property
def client(self):
    """Return the client for this server, as produced by ``get_client()``."""
    # NOTE(review): presumably ``cached_property`` evaluates this once and
    # stores the result on the instance. The surrounding class appears to be
    # declared with ``@attr.frozen`` -- confirm the cache can actually be
    # written on a frozen instance (TODO verify).
    return self.get_client()

def get_client(self):
from bentoml.client import Client

Client.wait_until_server_is_ready(self._host, self._port, 10)
return Client.from_url(f"http://localhost:{self._port}")
Client.wait_until_server_is_ready(
host=self.host, port=self.port, timeout=self.timeout
)
return Client.from_url(f"http://localhost:{self.port}")

def stop(self) -> None:
    """Kill the server subprocess held in ``self.process``."""
    server_process = self.process
    server_process.kill()

@property
def process(self) -> subprocess.Popen[bytes]:
return self._process

@property
def address(self) -> str:
return f"{self._host}:{self._port}"
return f"{self.host}:{self.port}"

def __enter__(self):
yield self

def __exit__(
self,
exc_type: type[BaseException],
exc_value: BaseException,
traceback_type: TracebackType,
):
try:
self.stop()
except Exception as e: # pylint: disable=broad-except
logger.error(f"Error stopping server: {e}", exc_info=e)
traceback.print_exception(exc_type, exc_value, traceback_type)
aarnphm marked this conversation as resolved.
Show resolved Hide resolved
111 changes: 61 additions & 50 deletions src/bentoml/bentos.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from ._internal.tag import Tag
from ._internal.bento import Bento
from ._internal.utils import resolve_user_filepath
from ._internal.server.server import Server
from ._internal.bento.build_config import BentoBuildConfig
from ._internal.configuration.containers import BentoMLContainer

if TYPE_CHECKING:
from ._internal.bento import BentoStore
from ._internal.server.server import ServerHandle

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -427,76 +427,87 @@ def containerize(bento_tag: Tag | str, **kwargs: t.Any) -> bool:
@inject
def serve(
bento: str,
production: bool = False,
port: int = Provide[BentoMLContainer.http.port],
host: str = Provide[BentoMLContainer.http.host],
server_type: str = "http",
api_workers: int | None = Provide[BentoMLContainer.api_server_workers],
backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
reload: bool = False,
production: bool = False,
host: str | None = None,
port: int | None = None,
working_dir: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile: str | None = None,
ssl_ca_certs: str | None = None,
# HTTP-specific args
ssl_keyfile_password: str | None = None,
ssl_version: int | None = None,
ssl_cert_reqs: int | None = None,
ssl_ciphers: str | None = None,
# GRPC-specific args
api_workers: int | None = Provide[BentoMLContainer.api_server_workers],
backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile],
ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile],
ssl_keyfile_password: str | None = Provide[BentoMLContainer.ssl.keyfile_password],
ssl_version: int | None = Provide[BentoMLContainer.ssl.version],
ssl_cert_reqs: int | None = Provide[BentoMLContainer.ssl.cert_reqs],
ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs],
ssl_ciphers: str | None = Provide[BentoMLContainer.ssl.ciphers],
enable_reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled],
enable_channelz: bool = Provide[BentoMLContainer.grpc.channelz.enabled],
max_concurrent_streams: int
| None = Provide[BentoMLContainer.grpc.max_concurrent_streams],
) -> Server:
"""Launch a BentoServer and returns a client that exposes all APIs defined in target service"""
) -> ServerHandle:
from .serve import construct_ssl_args
from ._internal.server.server import ServerHandle

if server_type.lower() not in ["http", "grpc"]:
server_type = server_type.lower()
if server_type not in ["http", "grpc"]:
raise ValueError('Server type must either be "http" or "grpc"')

args = [
if production and reload:
raise ValueError("'reload' and 'production' are mutually exclusive.")
aarnphm marked this conversation as resolved.
Show resolved Hide resolved

if server_type == "http":
serve_cmd = "serve-http"
if host is None:
host = BentoMLContainer.http.host.get()
if port is None:
port = BentoMLContainer.http.port.get()
else:
serve_cmd = "serve-grpc"
if host is None:
host = BentoMLContainer.grpc.host.get()
if port is None:
port = BentoMLContainer.grpc.port.get()

assert host is not None and port is not None
args: t.List[str] = [
"-m",
"bentoml",
"serve",
serve_cmd,
bento,
"--port",
str(port),
"--host",
host,
"--port",
str(port),
"--backlog",
str(backlog),
*construct_ssl_args(
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
ssl_keyfile_password=ssl_keyfile_password,
ssl_version=ssl_version,
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
),
]
if production:
args.append("--production")
if reload:
args.extend(["--reload", str(reload)])
elif reload:
args.append("--reload")
aarnphm marked this conversation as resolved.
Show resolved Hide resolved

if api_workers is not None:
args.extend(["--api-workers", str(api_workers)])
if working_dir is not None:
args.extend(["--working-dir", str(working_dir)])
if ssl_certfile is not None:
args.extend(["--ssl-certfile", ssl_certfile])
if ssl_keyfile is not None:
args.extend(["--ssl-keyfile", ssl_keyfile])
if ssl_ca_certs is not None:
args.extend(["--ssl-ca-certs", ssl_ca_certs])
if server_type.lower() == "http":
if ssl_keyfile_password is not None:
args.extend(["--ssl-keyfile-password", ssl_keyfile_password])
if ssl_version is not None:
args.extend(["--ssl-version", str(ssl_version)])
if ssl_cert_reqs is not None:
args.extend(["--ssl-cert-reqs", str(ssl_cert_reqs)])
if ssl_ciphers is not None:
args.extend(["--ssl-ciphers", ssl_ciphers])
if server_type.lower() == "grpc":
if enable_reflection:
args.extend(["--enable-reflection", str(enable_reflection)])
if enable_channelz:
args.extend(["--enable-channelz", str(enable_channelz)])
if max_concurrent_streams is not None:
args.extend(["--max-concurrent-streams", str(max_concurrent_streams)])

process = subprocess.Popen(args, executable=sys.executable)

return Server(process, host, port)
if enable_reflection:
args.append("--enable-reflection")
if enable_channelz:
args.append("--enable-channelz")
if max_concurrent_streams is not None:
args.extend(["--max-concurrent-streams", str(max_concurrent_streams)])

return ServerHandle(
process=subprocess.Popen(args, executable=sys.executable), host=host, port=port
)