Skip to content

Commit

Permalink
fix: make sure not to override user set values for from_sample (#3610)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarnphm authored Mar 2, 2023
1 parent 699c9c2 commit 246ad6c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
6 changes: 4 additions & 2 deletions src/bentoml/_internal/io_descriptors/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,10 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]:
raise BentoMLException(
f"Failed to create a 'numpy.ndarray' from given sample {sample}"
) from None
self._dtype = sample.dtype
self._shape = sample.shape
if self._dtype is None:
self._dtype = sample.dtype
if self._shape is None:
self._shape = sample.shape
return sample

async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
Expand Down
12 changes: 8 additions & 4 deletions src/bentoml/_internal/io_descriptors/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,10 @@ def predict(inputs: pd.DataFrame) -> pd.DataFrame: ...
raise InvalidArgument(
f"Failed to create a 'pd.DataFrame' from sample {sample}: {e}"
) from None
self._shape = sample.shape
self._columns = [str(i) for i in list(sample.columns)]
if self._shape is None:
self._shape = sample.shape
if self._columns is None:
self._columns = [str(i) for i in list(sample.columns)]
if self._dtype is None:
self._dtype = sample.dtypes
return sample
Expand Down Expand Up @@ -933,8 +935,10 @@ def predict(inputs: pd.Series) -> pd.Series: ...
"""
if not isinstance(sample, pd.Series):
sample = pd.Series(sample)
self._dtype = sample.dtype
self._shape = sample.shape
if self._dtype is None:
self._dtype = sample.dtype
if self._shape is None:
self._shape = sample.shape
return sample

def input_type(self) -> LazyType[ext.PdSeries]:
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/_internal/io/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ def test_verify_numpy_ndarray(caplog: LogCaptureFixture):
assert "Failed to reshape" in caplog.text


def test_from_sample_ensure_not_override():
example = NumpyNdarray.from_sample(np.ones((2, 2, 3)), dtype=np.float32)
assert example._dtype == np.float32

example = NumpyNdarray.from_sample(np.ones((2, 2, 3)), shape=(2, 2, 3))
assert example._shape == (2, 2, 3)


def generate_1d_array(dtype: pb.NDArray.DType.ValueType, length: int = 3):
if dtype == pb.NDArray.DTYPE_BOOL:
return [True] * length
Expand Down

0 comments on commit 246ad6c

Please sign in to comment.