Skip to content

Commit

Permalink
fix: handle backward protobuf version
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <[email protected]>
  • Loading branch information
aarnphm committed Feb 16, 2023
1 parent a61379a commit b337063
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 74 deletions.
79 changes: 72 additions & 7 deletions src/bentoml/_internal/io_descriptors/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing as t
import logging
from typing import TYPE_CHECKING
from functools import lru_cache

from starlette.requests import Request
from multipart.multipart import parse_options_header
Expand All @@ -19,25 +20,25 @@
from ...exceptions import InvalidArgument
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ...grpc.utils import import_generated_stubs
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from typing_extensions import Self

from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2 as pb_v1alpha1

from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

FileKind: t.TypeAlias = t.Literal["binaryio", "textio"]
else:
from bentoml.grpc.utils import import_generated_stubs

pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")

FileType = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]]

Expand Down Expand Up @@ -137,13 +138,19 @@ def _from_sample(self, sample: FileType | str) -> FileType:
sample = FileLike[bytes](sample, "<sample>")
elif isinstance(sample, (str, os.PathLike)):
p = resolve_user_filepath(sample, ctx=None)
self._mime_type = filetype.guess_mime(p)
try:
mime = filetype.guess_mime(p)
if mime is None:
raise ValueError(f"could not guess MIME type of file {p}")
except Exception as e:
raise BadInput(f"failed to guess MIME type of {p}: {e}")
self._mime_type = mime
with open(p, "rb") as f:
sample = FileLike[bytes](f, "<sample>")
return sample

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in File spec: {spec}")
return cls(**spec["args"])
Expand Down Expand Up @@ -199,6 +206,21 @@ async def to_proto(self, obj: FileType) -> pb.File:

return pb.File(kind=self._mime_type, content=body)

async def to_proto_v1alpha1(self, obj: FileType) -> pb_v1alpha1.File:
if isinstance(obj, bytes):
body = obj
else:
body = obj.read()

try:
kind = mimetype_to_filetype_pb_map()[self._mime_type]
except KeyError:
raise BadInput(
f"{self._mime_type} doesn't have a corresponding File 'kind'"
) from None

return pb_v1alpha1.File(kind=kind, content=body)

async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
raise NotImplementedError

Expand Down Expand Up @@ -243,10 +265,29 @@ async def from_http_request(self, request: Request) -> FileLike[bytes]:
f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead"
)

async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
async def from_proto(
self, field: pb.File | pb_v1alpha1.File | bytes
) -> FileLike[bytes]:
# check if the request message has the correct field
if isinstance(field, bytes):
content = field
elif isinstance(field, pb_v1alpha1.File):
mapping = filetype_pb_to_mimetype_map()
if field.kind:
try:
mime_type = mapping[field.kind]
if mime_type != self._mime_type:
raise BadInput(
f"Inferred mime_type from 'kind' is '{mime_type}', while '{self!r}' is expecting '{self._mime_type}'",
)
except KeyError:
raise BadInput(
f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb_v1alpha1.File.FileType.items()]}",
) from None
content = field.content
if not content:
raise BadInput("Content is empty!") from None
return FileLike[bytes](io.BytesIO(content), "<content>")
else:
assert isinstance(field, pb.File)
if field.kind and field.kind != self._mime_type:
Expand All @@ -258,3 +299,27 @@ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
raise BadInput("Content is empty!") from None

return FileLike[bytes](io.BytesIO(content), "<content>")


# v1alpha1 backward compatibility
@lru_cache(maxsize=1)
def filetype_pb_to_mimetype_map() -> dict[pb_v1alpha1.File.FileType.ValueType, str]:
return {
pb_v1alpha1.File.FILE_TYPE_CSV: "text/csv",
pb_v1alpha1.File.FILE_TYPE_PLAINTEXT: "text/plain",
pb_v1alpha1.File.FILE_TYPE_JSON: "application/json",
pb_v1alpha1.File.FILE_TYPE_BYTES: "application/octet-stream",
pb_v1alpha1.File.FILE_TYPE_PDF: "application/pdf",
pb_v1alpha1.File.FILE_TYPE_PNG: "image/png",
pb_v1alpha1.File.FILE_TYPE_JPEG: "image/jpeg",
pb_v1alpha1.File.FILE_TYPE_GIF: "image/gif",
pb_v1alpha1.File.FILE_TYPE_TIFF: "image/tiff",
pb_v1alpha1.File.FILE_TYPE_BMP: "image/bmp",
pb_v1alpha1.File.FILE_TYPE_WEBP: "image/webp",
pb_v1alpha1.File.FILE_TYPE_SVG: "image/svg+xml",
}


@lru_cache(maxsize=1)
def mimetype_to_filetype_pb_map() -> dict[str, pb_v1alpha1.File.FileType.ValueType]:
return {v: k for k, v in filetype_pb_to_mimetype_map().items()}
56 changes: 48 additions & 8 deletions src/bentoml/_internal/io_descriptors/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ...exceptions import InvalidArgument
from ...exceptions import InternalServerError
from ...exceptions import MissingDependencyException
from ...grpc.utils import import_generated_stubs
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType
Expand All @@ -31,26 +32,24 @@

import PIL
import PIL.Image
from typing_extensions import Self

from bentoml.grpc.v1 import service_pb2 as pb

from .. import external_typing as ext
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context
from ...grpc.v1 import service_pb2 as pb
from ...grpc.v1alpha1 import service_pb2 as pb_v1alpha1

_Mode = t.Literal[
"1", "CMYK", "F", "HSV", "I", "L", "LAB", "P", "RGB", "RGBA", "RGBX", "YCbCr"
]
else:
from bentoml.grpc.utils import import_generated_stubs

# NOTE: pillow-simd only benefits users who want to do preprocessing
# TODO: add options for users to choose between simd and native mode
PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=PIL_EXC_MSG)
PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=PIL_EXC_MSG)

pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")

# NOTES: we will keep type in quotation to avoid backward compatibility
# with numpy < 1.20, since we will use the latest stubs from the main branch of numpy.
Expand Down Expand Up @@ -245,7 +244,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in Image spec: {spec}")

Expand Down Expand Up @@ -387,9 +386,27 @@ async def to_http_response(
headers={"content-disposition": content_disposition},
)

async def from_proto(self, field: pb.File | bytes) -> ImageType:
async def from_proto(self, field: pb.File | pb_v1alpha1.File | bytes) -> ImageType:
if isinstance(field, bytes):
content = field
elif isinstance(field, pb_v1alpha1.File):
from .file import filetype_pb_to_mimetype_map

mapping = filetype_pb_to_mimetype_map()
if field.kind:
try:
mime_type = mapping[field.kind]
if mime_type != self._mime_type:
raise BadInput(
f"Inferred mime_type from 'kind' is '{mime_type}', while '{self!r}' is expecting '{self._mime_type}'",
)
except KeyError:
raise BadInput(
f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb_v1alpha1.File.FileType.items()]}",
) from None
if not field.content:
raise BadInput("Content is empty!") from None
return PIL.Image.open(io.BytesIO(field.content))
else:
assert isinstance(field, pb.File)
if field.kind and field.kind != self._mime_type:
Expand All @@ -402,6 +419,29 @@ async def from_proto(self, field: pb.File | bytes) -> ImageType:

return PIL.Image.open(io.BytesIO(content))

async def to_proto_v1alpha1(self, obj: ImageType) -> pb_v1alpha1.File:
from .file import mimetype_to_filetype_pb_map

try:
kind = mimetype_to_filetype_pb_map()[self._mime_type]
except KeyError:
raise BadInput(
f"{self._mime_type} doesn't have a corresponding File 'kind'"
) from None

if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj):
image = PIL.Image.fromarray(obj, mode=self._pilmode)
elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(obj):
image = obj
else:
raise BadInput(
f"Unsupported Image type received: '{type(obj)}', the Image IO descriptor only supports 'np.ndarray' and 'PIL.Image'.",
) from None
ret = io.BytesIO()
image.save(ret, format=self._format)

return pb_v1alpha1.File(kind=kind, content=ret.getvalue())

async def to_proto(self, obj: ImageType) -> pb.File:
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj):
image = PIL.Image.fromarray(obj, mode=self._pilmode)
Expand Down
35 changes: 29 additions & 6 deletions src/bentoml/_internal/io_descriptors/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .base import IODescriptor
from ...exceptions import InvalidArgument
from ...exceptions import BentoMLException
from ...grpc.utils import import_generated_stubs
from ..service.openapi import SUCCESS_DESCRIPTION
from ..utils.formparser import populate_multipart_requests
from ..utils.formparser import concat_to_multipart_response
Expand All @@ -21,17 +22,17 @@
if TYPE_CHECKING:
from types import UnionType

from typing_extensions import Self
from google.protobuf import message as _message

from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2 as pb_v1alpha1

from .base import OpenAPIResponse
from ..types import LazyType
from ..context import InferenceApiContext as Context
else:
from bentoml.grpc.utils import import_generated_stubs

pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")


class Multipart(IODescriptor[t.Dict[str, t.Any]], descriptor_id="bentoml.io.Multipart"):
Expand Down Expand Up @@ -202,7 +203,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in Multipart spec: {spec}")
return Multipart(
Expand Down Expand Up @@ -321,7 +322,23 @@ async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]:
)
return dict(zip(self._inputs.keys(), reqs))

async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
@t.overload
async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: t.Literal["v1"]
) -> pb.Multipart:
...

@t.overload
async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.Multipart:
...

async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: str
) -> _message.Message:
pb, _ = import_generated_stubs(version)

self.validate_input_mapping(obj)
resps = await asyncio.gather(
*tuple(
Expand All @@ -341,3 +358,9 @@ async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
)
)
)

async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
return await self._to_proto_impl(obj, version="v1")

async def to_proto_v1alpha1(self, obj: dict[str, t.Any]) -> pb_v1alpha1.Multipart:
return await self._to_proto_impl(obj, version="v1alpha1")
Loading

0 comments on commit b337063

Please sign in to comment.