Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(grpc): handle backward protocol version #3332

Merged
merged 1 commit into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 72 additions & 7 deletions src/bentoml/_internal/io_descriptors/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing as t
import logging
from typing import TYPE_CHECKING
from functools import lru_cache

from starlette.requests import Request
from multipart.multipart import parse_options_header
Expand All @@ -19,25 +20,25 @@
from ...exceptions import InvalidArgument
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ...grpc.utils import import_generated_stubs
aarnphm marked this conversation as resolved.
Show resolved Hide resolved
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from typing_extensions import Self

from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2 as pb_v1alpha1

from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

FileKind: t.TypeAlias = t.Literal["binaryio", "textio"]
else:
from bentoml.grpc.utils import import_generated_stubs

pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")

FileType = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]]

Expand Down Expand Up @@ -137,13 +138,19 @@ def _from_sample(self, sample: FileType | str) -> FileType:
sample = FileLike[bytes](sample, "<sample>")
elif isinstance(sample, (str, os.PathLike)):
p = resolve_user_filepath(sample, ctx=None)
self._mime_type = filetype.guess_mime(p)
try:
mime = filetype.guess_mime(p)
if mime is None:
raise ValueError(f"could not guess MIME type of file {p}")
except Exception as e:
raise BadInput(f"failed to guess MIME type of {p}: {e}")
self._mime_type = mime
with open(p, "rb") as f:
sample = FileLike[bytes](f, "<sample>")
return sample

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in File spec: {spec}")
return cls(**spec["args"])
Expand Down Expand Up @@ -199,6 +206,21 @@ async def to_proto(self, obj: FileType) -> pb.File:

return pb.File(kind=self._mime_type, content=body)

async def to_proto_v1alpha1(self, obj: FileType) -> pb_v1alpha1.File:
    """Serialize ``obj`` into a v1alpha1 ``File`` protobuf message.

    Args:
        obj: Either raw ``bytes`` or an object whose ``read()`` returns the
            payload to send.

    Returns:
        A ``pb_v1alpha1.File`` whose ``kind`` is the enum value looked up from
        this descriptor's MIME type and whose ``content`` is the payload.

    Raises:
        BadInput: If this descriptor's MIME type has no entry in the
            MIME-type-to-kind mapping.
    """
    # The payload is obtained first, so a failing ``read()`` surfaces before
    # any MIME type validation error.
    payload = obj if isinstance(obj, bytes) else obj.read()

    try:
        file_kind = mimetype_to_filetype_pb_map()[self._mime_type]
    except KeyError:
        # ``from None`` hides the internal KeyError from the caller's traceback.
        raise BadInput(
            f"{self._mime_type} doesn't have a corresponding File 'kind'"
        ) from None

    return pb_v1alpha1.File(kind=file_kind, content=payload)

async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
    """Placeholder for protobuf deserialization; always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    raise NotImplementedError()

Expand Down Expand Up @@ -243,10 +265,29 @@ async def from_http_request(self, request: Request) -> FileLike[bytes]:
f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead"
)

async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
async def from_proto(
self, field: pb.File | pb_v1alpha1.File | bytes
) -> FileLike[bytes]:
# check if the request message has the correct field
if isinstance(field, bytes):
content = field
elif isinstance(field, pb_v1alpha1.File):
mapping = filetype_pb_to_mimetype_map()
if field.kind:
try:
mime_type = mapping[field.kind]
if mime_type != self._mime_type:
raise BadInput(
f"Inferred mime_type from 'kind' is '{mime_type}', while '{self!r}' is expecting '{self._mime_type}'",
)
except KeyError:
raise BadInput(
f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb_v1alpha1.File.FileType.items()]}",
) from None
content = field.content
if not content:
raise BadInput("Content is empty!") from None
return FileLike[bytes](io.BytesIO(content), "<content>")
else:
assert isinstance(field, pb.File)
if field.kind and field.kind != self._mime_type:
Expand All @@ -258,3 +299,27 @@ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
raise BadInput("Content is empty!") from None

return FileLike[bytes](io.BytesIO(content), "<content>")


# v1alpha1 backward compatibility
@lru_cache(maxsize=1)
def filetype_pb_to_mimetype_map() -> dict[pb_v1alpha1.File.FileType.ValueType, str]:
    """Return the mapping from v1alpha1 ``File`` kind enum values to MIME types.

    The result is memoized by ``lru_cache``, so every call returns the same
    dict object; callers should treat it as read-only.
    """
    file_pb = pb_v1alpha1.File
    kind_mime_pairs = (
        (file_pb.FILE_TYPE_CSV, "text/csv"),
        (file_pb.FILE_TYPE_PLAINTEXT, "text/plain"),
        (file_pb.FILE_TYPE_JSON, "application/json"),
        (file_pb.FILE_TYPE_BYTES, "application/octet-stream"),
        (file_pb.FILE_TYPE_PDF, "application/pdf"),
        (file_pb.FILE_TYPE_PNG, "image/png"),
        (file_pb.FILE_TYPE_JPEG, "image/jpeg"),
        (file_pb.FILE_TYPE_GIF, "image/gif"),
        (file_pb.FILE_TYPE_TIFF, "image/tiff"),
        (file_pb.FILE_TYPE_BMP, "image/bmp"),
        (file_pb.FILE_TYPE_WEBP, "image/webp"),
        (file_pb.FILE_TYPE_SVG, "image/svg+xml"),
    )
    return dict(kind_mime_pairs)


@lru_cache(maxsize=1)
def mimetype_to_filetype_pb_map() -> dict[str, pb_v1alpha1.File.FileType.ValueType]:
    """Return the inverse of ``filetype_pb_to_mimetype_map``: MIME type to kind.

    The result is memoized by ``lru_cache``, so every call returns the same
    dict object; callers should treat it as read-only.
    """
    kind_to_mime = filetype_pb_to_mimetype_map()
    # Swap each (kind, mime) pair while keeping the forward map's order.
    return dict(zip(kind_to_mime.values(), kind_to_mime.keys()))
56 changes: 48 additions & 8 deletions src/bentoml/_internal/io_descriptors/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ...exceptions import InvalidArgument
from ...exceptions import InternalServerError
from ...exceptions import MissingDependencyException
from ...grpc.utils import import_generated_stubs
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType
Expand All @@ -31,26 +32,24 @@

import PIL
import PIL.Image
from typing_extensions import Self

from bentoml.grpc.v1 import service_pb2 as pb

from .. import external_typing as ext
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context
from ...grpc.v1 import service_pb2 as pb
from ...grpc.v1alpha1 import service_pb2 as pb_v1alpha1

_Mode = t.Literal[
"1", "CMYK", "F", "HSV", "I", "L", "LAB", "P", "RGB", "RGBA", "RGBX", "YCbCr"
]
else:
from bentoml.grpc.utils import import_generated_stubs

# NOTE: pillow-simd only benefits users who want to do preprocessing
# TODO: add options for users to choose between simd and native mode
PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=PIL_EXC_MSG)
PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=PIL_EXC_MSG)

pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")

# NOTE: we keep the type in quotation marks to avoid backward-compatibility issues
# with numpy < 1.20, since we will use the latest stubs from the main branch of numpy.
Expand Down Expand Up @@ -245,7 +244,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in Image spec: {spec}")

Expand Down Expand Up @@ -387,9 +386,27 @@ async def to_http_response(
headers={"content-disposition": content_disposition},
)

async def from_proto(self, field: pb.File | bytes) -> ImageType:
async def from_proto(self, field: pb.File | pb_v1alpha1.File | bytes) -> ImageType:
if isinstance(field, bytes):
content = field
elif isinstance(field, pb_v1alpha1.File):
from .file import filetype_pb_to_mimetype_map

mapping = filetype_pb_to_mimetype_map()
if field.kind:
try:
mime_type = mapping[field.kind]
if mime_type != self._mime_type:
raise BadInput(
f"Inferred mime_type from 'kind' is '{mime_type}', while '{self!r}' is expecting '{self._mime_type}'",
)
except KeyError:
raise BadInput(
f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb_v1alpha1.File.FileType.items()]}",
) from None
if not field.content:
raise BadInput("Content is empty!") from None
return PIL.Image.open(io.BytesIO(field.content))
else:
assert isinstance(field, pb.File)
if field.kind and field.kind != self._mime_type:
Expand All @@ -402,6 +419,29 @@ async def from_proto(self, field: pb.File | bytes) -> ImageType:

return PIL.Image.open(io.BytesIO(content))

async def to_proto_v1alpha1(self, obj: ImageType) -> pb_v1alpha1.File:
    """Encode an image into a v1alpha1 ``File`` protobuf message.

    Args:
        obj: A ``numpy.ndarray`` or a ``PIL.Image.Image``.

    Returns:
        A ``pb_v1alpha1.File`` whose ``kind`` is the enum value looked up from
        this descriptor's MIME type and whose ``content`` is the image saved
        in ``self._format``.

    Raises:
        BadInput: If the MIME type has no entry in the MIME-type-to-kind
            mapping, or if ``obj`` is neither an ndarray nor a PIL image.
    """
    # Imported inside the function rather than at module level.
    from .file import mimetype_to_filetype_pb_map

    # The MIME type is validated before the image object is touched.
    try:
        file_kind = mimetype_to_filetype_pb_map()[self._mime_type]
    except KeyError:
        raise BadInput(
            f"{self._mime_type} doesn't have a corresponding File 'kind'"
        ) from None

    is_ndarray = LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj)
    if not is_ndarray and not LazyType["PIL.Image.Image"](
        "PIL.Image.Image"
    ).isinstance(obj):
        raise BadInput(
            f"Unsupported Image type received: '{type(obj)}', the Image IO descriptor only supports 'np.ndarray' and 'PIL.Image'.",
        ) from None

    # Arrays are converted to a PIL image first; PIL images are used as-is.
    image = PIL.Image.fromarray(obj, mode=self._pilmode) if is_ndarray else obj

    buffer = io.BytesIO()
    image.save(buffer, format=self._format)

    return pb_v1alpha1.File(kind=file_kind, content=buffer.getvalue())

async def to_proto(self, obj: ImageType) -> pb.File:
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj):
image = PIL.Image.fromarray(obj, mode=self._pilmode)
Expand Down
35 changes: 29 additions & 6 deletions src/bentoml/_internal/io_descriptors/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .base import IODescriptor
from ...exceptions import InvalidArgument
from ...exceptions import BentoMLException
from ...grpc.utils import import_generated_stubs
from ..service.openapi import SUCCESS_DESCRIPTION
from ..utils.formparser import populate_multipart_requests
from ..utils.formparser import concat_to_multipart_response
Expand All @@ -21,17 +22,17 @@
if TYPE_CHECKING:
from types import UnionType

from typing_extensions import Self
from google.protobuf import message as _message

from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2 as pb_v1alpha1

from .base import OpenAPIResponse
from ..types import LazyType
from ..context import InferenceApiContext as Context
else:
from bentoml.grpc.utils import import_generated_stubs

pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")


class Multipart(IODescriptor[t.Dict[str, t.Any]], descriptor_id="bentoml.io.Multipart"):
Expand Down Expand Up @@ -202,7 +203,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in Multipart spec: {spec}")
return Multipart(
Expand Down Expand Up @@ -321,7 +322,23 @@ async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]:
)
return dict(zip(self._inputs.keys(), reqs))

async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
@t.overload
async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: t.Literal["v1"]
) -> pb.Multipart:
...

@t.overload
async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.Multipart:
...

async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: str
) -> _message.Message:
pb, _ = import_generated_stubs(version)

self.validate_input_mapping(obj)
resps = await asyncio.gather(
*tuple(
Expand All @@ -341,3 +358,9 @@ async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
)
)
)

async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
    """Serialize ``obj`` by delegating to ``_to_proto_impl`` with ``version="v1"``.

    Args:
        obj: The mapping to serialize; passed through unchanged.

    Returns:
        Whatever ``_to_proto_impl`` produces for the ``v1`` protocol version.
    """
    message = await self._to_proto_impl(obj, version="v1")
    return message

async def to_proto_v1alpha1(self, obj: dict[str, t.Any]) -> pb_v1alpha1.Multipart:
    """Serialize ``obj`` by delegating to ``_to_proto_impl`` with ``version="v1alpha1"``.

    Args:
        obj: The mapping to serialize; passed through unchanged.

    Returns:
        Whatever ``_to_proto_impl`` produces for the ``v1alpha1`` protocol version.
    """
    message = await self._to_proto_impl(obj, version="v1alpha1")
    return message
Loading