fix: make sure the resulting proto uses the correct stubs version
Signed-off-by: Aaron Pham <[email protected]>
aarnphm committed Dec 16, 2022
1 parent c000445 commit 4aa20e0
Showing 3 changed files with 181 additions and 25 deletions.
32 changes: 28 additions & 4 deletions src/bentoml/_internal/io_descriptors/multipart.py
@@ -22,15 +22,17 @@
if TYPE_CHECKING:
from types import UnionType

from typing_extensions import Self
from google.protobuf import message as _message

from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2 as pb_v1alpha1

from .base import OpenAPIResponse
from ..types import LazyType
from ..context import InferenceApiContext as Context
else:
pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")


class Multipart(IODescriptor[t.Dict[str, t.Any]], descriptor_id="bentoml.io.Multipart"):
@@ -201,7 +203,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in Multipart spec: {spec}")
return Multipart(
@@ -313,7 +315,23 @@ async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]:
)
return dict(zip(self._inputs.keys(), reqs))

async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
@t.overload
async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: t.Literal["v1"]
) -> pb.Multipart:
...

@t.overload
async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.Multipart:
...

async def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: str
) -> _message.Message:
pb, _ = import_generated_stubs(version)

self.validate_input_mapping(obj)
resps = await asyncio.gather(
*tuple(
@@ -333,3 +351,9 @@ async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
)
)
)

async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
return await self._to_proto_impl(obj, version="v1")

async def to_proto_v1alpha1(self, obj: dict[str, t.Any]) -> pb_v1alpha1.Multipart:
return await self._to_proto_impl(obj, version="v1alpha1")
116 changes: 101 additions & 15 deletions src/bentoml/_internal/io_descriptors/numpy.py
@@ -18,21 +18,24 @@
from ...exceptions import BentoMLException
from ...exceptions import UnprocessableEntity
from ...grpc.utils import import_generated_stubs
from ...grpc.utils import LATEST_PROTOCOL_VERSION
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType

if TYPE_CHECKING:
import numpy as np
from typing_extensions import Self
from google.protobuf import message as _message

from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2 as pb_v1alpha1

from .. import external_typing as ext
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context
else:
pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")
np = LazyLoader("np", globals(), "numpy")

logger = logging.getLogger(__name__)
@@ -59,8 +62,25 @@
}


@lru_cache(maxsize=1)
def dtypepb_to_npdtype_map() -> dict[pb.NDArray.DType.ValueType, ext.NpDTypeLike]:
@t.overload
def dtypepb_to_npdtype_map(
version: t.Literal["v1"] = ...,
) -> dict[pb.NDArray.DType.ValueType, ext.NpDTypeLike]:
...


@t.overload
def dtypepb_to_npdtype_map(
version: t.Literal["v1alpha1"] = ...,
) -> dict[pb_v1alpha1.NDArray.DType.ValueType, ext.NpDTypeLike]:
...


@lru_cache(maxsize=2)
def dtypepb_to_npdtype_map(
version: str = LATEST_PROTOCOL_VERSION,
) -> dict[int, ext.NpDTypeLike]:
pb, _ = import_generated_stubs(version)
# pb.NDArray.Dtype -> np.dtype
return {
pb.NDArray.DTYPE_FLOAT: np.dtype("float32"),
@@ -74,9 +94,26 @@ def dtypepb_to_npdtype_map() -> dict[pb.NDArray.DType.ValueType, ext.NpDTypeLike]:
}


@lru_cache(maxsize=1)
def dtypepb_to_fieldpb_map() -> dict[pb.NDArray.DType.ValueType, str]:
return {k: npdtype_to_fieldpb_map()[v] for k, v in dtypepb_to_npdtype_map().items()}
@t.overload
def dtypepb_to_fieldpb_map(
version: t.Literal["v1"] = ...,
) -> dict[pb.NDArray.DType.ValueType, str]:
...


@t.overload
def dtypepb_to_fieldpb_map(
version: t.Literal["v1alpha1"] = ...,
) -> dict[pb_v1alpha1.NDArray.DType.ValueType, str]:
...


@lru_cache(maxsize=2)
def dtypepb_to_fieldpb_map(version: str = LATEST_PROTOCOL_VERSION) -> dict[int, str]:
return {
k: npdtype_to_fieldpb_map()[v]
for k, v in dtypepb_to_npdtype_map(version).items()
}


@lru_cache(maxsize=1)
@@ -85,10 +122,26 @@ def fieldpb_to_npdtype_map() -> dict[str, ext.NpDTypeLike]:
return {k: np.dtype(v) for k, v in FIELDPB_TO_NPDTYPE_NAME_MAP.items()}


@lru_cache(maxsize=1)
def npdtype_to_dtypepb_map() -> dict[ext.NpDTypeLike, pb.NDArray.DType.ValueType]:
@t.overload
def npdtype_to_dtypepb_map(
version: t.Literal["v1"] = ...,
) -> dict[ext.NpDTypeLike, pb.NDArray.DType.ValueType]:
...


@t.overload
def npdtype_to_dtypepb_map(
version: t.Literal["v1alpha1"] = ...,
) -> dict[ext.NpDTypeLike, pb_v1alpha1.NDArray.DType.ValueType]:
...


@lru_cache(maxsize=2)
def npdtype_to_dtypepb_map(
version: str = LATEST_PROTOCOL_VERSION,
) -> dict[ext.NpDTypeLike, int]:
# np.dtype -> pb.NDArray.Dtype
return {v: k for k, v in dtypepb_to_npdtype_map().items()}
return {v: k for k, v in dtypepb_to_npdtype_map(version).items()}


@lru_cache(maxsize=1)
@@ -251,7 +304,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in NumpyNdarray spec: {spec}")
res = NumpyNdarray(**spec["args"])
@@ -475,6 +528,15 @@ async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
dtype: ext.NpDTypeLike = self._dtype
array = np.frombuffer(field, dtype=self._dtype)
else:
if isinstance(field, pb_v1alpha1.NDArray):
version = "v1alpha1"
elif isinstance(field, pb.NDArray):
version = "v1"
else:
raise BadInput(
f"Expected 'pb.NDArray' or 'bytes' but received {type(field)}"
) from None

# The behaviour of dtype are as follows:
# - if not provided:
# * All of the fields are empty, then we return a ``np.empty``.
@@ -486,11 +548,13 @@
dtype = None
else:
try:
dtype = dtypepb_to_npdtype_map()[field.dtype]
dtype = dtypepb_to_npdtype_map(version=version)[field.dtype]
except KeyError:
raise BadInput(f"{field.dtype} is invalid.") from None
if dtype is not None:
values_array = getattr(field, dtypepb_to_fieldpb_map()[field.dtype])
values_array = getattr(
field, dtypepb_to_fieldpb_map(version=version)[field.dtype]
)
else:
fieldpb = [
f.name for f, _ in field.ListFields() if f.name.endswith("_values")
@@ -520,7 +584,21 @@ async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
# We will try to run validation process before sending this of to the user.
return self.validate_array(array)

async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray:
@t.overload
async def _to_proto_impl(
self, obj: ext.NpNDArray, *, version: t.Literal["v1"]
) -> pb.NDArray:
...

@t.overload
async def _to_proto_impl(
self, obj: ext.NpNDArray, *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.NDArray:
...

async def _to_proto_impl(
self, obj: ext.NpNDArray, *, version: str
) -> _message.Message:
"""
Process given objects and convert it to grpc protobuf response.
@@ -535,9 +613,11 @@ async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray:
except BadInput as e:
raise e from None

pb, _ = import_generated_stubs(version)

try:
fieldpb = npdtype_to_fieldpb_map()[obj.dtype]
dtypepb = npdtype_to_dtypepb_map()[obj.dtype]
dtypepb = npdtype_to_dtypepb_map(version=version)[obj.dtype]
return pb.NDArray(
dtype=dtypepb,
shape=tuple(obj.shape),
@@ -547,3 +627,9 @@ async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray:
raise BadInput(
f"Unsupported dtype '{obj.dtype}' for response message.",
) from None

async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray:
return await self._to_proto_impl(obj, version="v1")

async def to_proto_v1alpha1(self, obj: ext.NpNDArray) -> pb_v1alpha1.NDArray:
return await self._to_proto_impl(obj, version="v1alpha1")
58 changes: 52 additions & 6 deletions src/bentoml/_internal/io_descriptors/pandas.py
@@ -31,16 +31,18 @@
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from typing_extensions import Self
from google.protobuf import message as _message

from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2 as pb_v1alpha1

from .. import external_typing as ext
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

else:
pb, _ = import_generated_stubs()
pb, _ = import_generated_stubs("v1")
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")
pd = LazyLoader("pd", globals(), "pandas", exc_msg=EXC_MSG)
np = LazyLoader("np", globals(), "numpy")

@@ -460,7 +462,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in PandasDataFrame spec: {spec}")
res = PandasDataFrame(**spec["args"])
@@ -673,7 +675,21 @@ def process_columns_contents(content: pb.Series) -> dict[str, t.Any]:
)
return self.validate_dataframe(dataframe)

async def to_proto(self, obj: ext.PdDataFrame) -> pb.DataFrame:
@t.overload
async def _to_proto_impl(
self, obj: ext.PdDataFrame, *, version: t.Literal["v1"]
) -> pb.DataFrame:
...

@t.overload
async def _to_proto_impl(
self, obj: ext.PdDataFrame, *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.DataFrame:
...

async def _to_proto_impl(
self, obj: ext.PdDataFrame, *, version: str
) -> _message.Message:
"""
Process given objects and convert it to grpc protobuf response.
@@ -686,6 +702,8 @@ async def to_proto(self, obj: ext.PdDataFrame) -> pb.DataFrame:
"""
from .numpy import npdtype_to_fieldpb_map

pb, _ = import_generated_stubs(version)

# TODO: support different serialization format
obj = self.validate_dataframe(obj)
mapping = npdtype_to_fieldpb_map()
@@ -713,6 +731,12 @@
],
)

async def to_proto(self, obj: ext.PdDataFrame) -> pb.DataFrame:
return await self._to_proto_impl(obj, version="v1")

async def to_proto_v1alpha1(self, obj: ext.PdDataFrame) -> pb_v1alpha1.DataFrame:
return await self._to_proto_impl(obj, version="v1alpha1")


class PandasSeries(
IODescriptor["ext.PdSeries"], descriptor_id="bentoml.io.PandasSeries"
@@ -904,7 +928,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> t.Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in PandasSeries spec: {spec}")
res = PandasSeries(**spec["args"])
@@ -1068,7 +1092,21 @@ async def from_proto(self, field: pb.Series | bytes) -> ext.PdSeries:

return self.validate_series(series)

async def to_proto(self, obj: ext.PdSeries) -> pb.Series:
@t.overload
async def _to_proto_impl(
self, obj: ext.PdSeries, *, version: t.Literal["v1"]
) -> pb.Series:
...

@t.overload
async def _to_proto_impl(
self, obj: ext.PdSeries, *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.Series:
...

async def _to_proto_impl(
self, obj: ext.PdSeries, *, version: str
) -> _message.Message:
"""
Process given objects and convert it to grpc protobuf response.
@@ -1081,6 +1119,8 @@ async def to_proto(self, obj: ext.PdSeries) -> pb.Series:
"""
from .numpy import npdtype_to_fieldpb_map

pb, _ = import_generated_stubs(version)

try:
obj = self.validate_series(obj, exception_cls=InvalidArgument)
except InvalidArgument as e:
Expand All @@ -1100,3 +1140,9 @@ async def to_proto(self, obj: ext.PdSeries) -> pb.Series:
raise InvalidArgument(
f"Unsupported dtype '{obj.dtype}' for response message."
) from None

async def to_proto(self, obj: ext.PdSeries) -> pb.Series:
return await self._to_proto_impl(obj, version="v1")

async def to_proto_v1alpha1(self, obj: ext.PdSeries) -> pb_v1alpha1.Series:
return await self._to_proto_impl(obj, version="v1alpha1")
