Skip to content

Commit

Permalink
fix(server): include system headers (#4418)
Browse files Browse the repository at this point in the history
* fix(server): include system headers

* fix

* fix

* opt
  • Loading branch information
bojiang authored Jan 18, 2024
1 parent 7029600 commit 84edd85
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions src/_bentoml_impl/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from pydantic import RootModel

from _bentoml_sdk import IODescriptor
from _bentoml_sdk.typing_utils import is_file_like
from _bentoml_sdk.typing_utils import is_image_type
from bentoml import __version__
from bentoml._internal.utils.uri import uri_to_path
Expand All @@ -43,10 +42,6 @@
MAX_RETRIES = 3


def _is_file(obj: t.Any) -> bool:
return isinstance(obj, pathlib.PurePath) or is_file_like(obj)


@attr.define
class HTTPClient(AbstractClient, t.Generic[C]):
client_cls: t.ClassVar[type[BaseClient]]
Expand All @@ -56,6 +51,8 @@ class HTTPClient(AbstractClient, t.Generic[C]):
media_type: str = "application/json"
timeout: int = 30
token: str | None = None
const_headers: dict[str, str] = attr.field(factory=dict)

_opened_files: list[io.BufferedReader] = attr.field(init=False, factory=list)
_temp_dir: tempfile.TemporaryDirectory[str] = attr.field(init=False)

Expand Down Expand Up @@ -84,7 +81,7 @@ def client(self) -> C:
),
)

@_temp_dir.default
@_temp_dir.default # type: ignore
def default_temp_dir(self) -> tempfile.TemporaryDirectory[str]:
return tempfile.TemporaryDirectory(prefix="bentoml-client-")

Expand Down Expand Up @@ -129,6 +126,7 @@ def __init__(
doc=route.get("doc"),
stream_output=route["output"].get("is_stream", False),
)
const_headers = {}
else:
for name, method in service.apis.items():
routes[name] = ClientEndpoint(
Expand All @@ -142,8 +140,22 @@ def __init__(
stream_output=method.is_stream,
)

self.__attrs_init__(
url=url, endpoints=routes, media_type=media_type, token=token
from bentoml._internal.context import component_context

const_headers = {
"Bento-Name": component_context.bento_name,
"Bento-Version": component_context.bento_version,
"Runner-Name": service.name,
"Yatai-Bento-Deployment-Name": component_context.yatai_bento_deployment_name,
"Yatai-Bento-Deployment-Namespace": component_context.yatai_bento_deployment_namespace,
}

self.__attrs_init__( # type: ignore
url=url,
endpoints=routes,
media_type=media_type,
token=token,
const_headers=const_headers,
)
super().__init__()

Expand All @@ -153,11 +165,13 @@ def serde(self) -> Serde:

return ALL_SERDE[self.media_type]()

@property
@cached_property
def default_headers(self) -> dict[str, str]:
headers = {"User-Agent": f"BentoML HTTP Client/{__version__}"}
headers = self.const_headers.copy()
headers["User-Agent"] = f"BentoML HTTP Client/{__version__}"
if self.token:
headers["Authorization"] = f"Bearer {self.token}"

return headers

def _build_request(
Expand Down

0 comments on commit 84edd85

Please sign in to comment.