Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: apply pep574 out-of-band pickling to DefaultContainer #3736

Merged
merged 4 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 75 additions & 59 deletions src/bentoml/_internal/runner/container.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
from __future__ import annotations

import abc
import sys
import base64
import pickle
import typing as t
import itertools

from simple_di import inject
from simple_di import Provide

from ..types import LazyType
from ..utils import LazyLoader
from ..configuration.containers import BentoMLContainer

if sys.version_info < (3, 8):
import pickle5 as pickle
else:
import pickle

from ..utils.pickle import pep574_dumps
from ..utils.pickle import pep574_loads

SingleType = t.TypeVar("SingleType")
BatchType = t.TypeVar("BatchType")
Expand All @@ -41,7 +33,7 @@

class Payload(t.NamedTuple):
data: bytes
meta: dict[str, bool | int | float | str]
meta: dict[str, bool | int | float | str | list[int]]
container: str
batch_size: int = -1

Expand All @@ -52,7 +44,7 @@ def create_payload(
cls,
data: bytes,
batch_size: int,
meta: dict[str, bool | int | float | str] | None = None,
meta: dict[str, bool | int | float | str | list[int]] | None = None,
) -> Payload:
return Payload(data, meta or {}, container=cls.__name__, batch_size=batch_size)

Expand Down Expand Up @@ -289,21 +281,23 @@ def to_payload(
# skip 0-dimensional array
if batch.shape:

buffers: list[pickle.PickleBuffer] = []
if not (batch.flags["C_CONTIGUOUS"] or batch.flags["F_CONTIGUOUS"]):
                # TODO: use fortran contiguous if it's faster
batch = np.ascontiguousarray(batch)
bs = pickle.dumps(batch, protocol=5, buffer_callback=buffers.append)

bs: bytes
concat_buffer_bs: bytes
indices: list[int]
bs, concat_buffer_bs, indices = pep574_dumps(batch)
bs_str = base64.b64encode(bs).decode("ascii")
buffer_bs = buffers[0].raw().tobytes()
# release memory
buffers[0].release()

return cls.create_payload(
buffer_bs,
concat_buffer_bs,
batch.shape[batch_dim],
{
"format": "pickle5",
"pickle_bytes": bs_str,
"pickle_bytes_str": bs_str,
"indices": indices,
},
)

Expand All @@ -320,10 +314,10 @@ def from_payload(
) -> ext.NpNDArray:
format = payload.meta.get("format", "default")
if format == "pickle5":
bs_str = payload.meta["pickle_bytes"]
bs_str = t.cast(str, payload.meta["pickle_bytes_str"])
bs = base64.b64decode(bs_str)
recovered_buffers = [pickle.PickleBuffer(payload.data)]
return pickle.loads(bs, buffers=recovered_buffers)
indices = t.cast(t.List[int], payload.meta["indices"])
return t.cast("ext.NpNDArray", pep574_loads(bs, payload.data, indices))

return pickle.loads(payload.data)

Expand All @@ -340,7 +334,6 @@ def batch_to_payloads(
return payloads

@classmethod
@inject
def from_batch_payloads(
cls,
payloads: t.Sequence[Payload],
Expand Down Expand Up @@ -387,12 +380,10 @@ def batch_to_batches(
]

@classmethod
@inject
def to_payload(
cls,
batch: ext.PdDataFrame | ext.PdSeries,
batch_dim: int,
plasma_db: ext.PlasmaClient | None = Provide[BentoMLContainer.plasma_db],
) -> Payload:
import pandas as pd

Expand All @@ -403,59 +394,60 @@ def to_payload(
if isinstance(batch, pd.Series):
batch = pd.DataFrame([batch])

if plasma_db:
return cls.create_payload(
plasma_db.put(batch).binary(),
batch.size,
{"plasma": True},
)
meta: dict[str, bool | int | float | str | list[int]] = {"format": "pickle5"}

bs: bytes
concat_buffer_bs: bytes
indices: list[int]
bs, concat_buffer_bs, indices = pep574_dumps(batch)

if indices:
meta["with_buffer"] = True
data = concat_buffer_bs
meta["pickle_bytes_str"] = base64.b64encode(bs).decode("ascii")
meta["indices"] = indices
else:
meta["with_buffer"] = False
data = bs

return cls.create_payload(
pickle.dumps(batch),
data,
batch.size,
{"plasma": False},
meta=meta,
)

@classmethod
@inject
def from_payload(
cls,
payload: Payload,
plasma_db: ext.PlasmaClient | None = Provide[BentoMLContainer.plasma_db],
) -> ext.PdDataFrame:
if payload.meta.get("plasma"):
import pyarrow.plasma as plasma

assert plasma_db
return plasma_db.get(plasma.ObjectID(payload.data))

return pickle.loads(payload.data)
if payload.meta["with_buffer"]:
bs_str = t.cast(str, payload.meta["pickle_bytes_str"])
bs = base64.b64decode(bs_str)
indices = t.cast(t.List[int], payload.meta["indices"])
return pep574_loads(bs, payload.data, indices)
else:
return pep574_loads(payload.data, b"", [])

@classmethod
@inject
def batch_to_payloads(
cls,
batch: ext.PdDataFrame,
indices: t.Sequence[int],
batch_dim: int = 0,
plasma_db: ext.PlasmaClient | None = Provide[BentoMLContainer.plasma_db],
) -> list[Payload]:
batches = cls.batch_to_batches(batch, indices, batch_dim)

payloads = [
cls.to_payload(subbatch, batch_dim, plasma_db) for subbatch in batches
]
payloads = [cls.to_payload(subbatch, batch_dim) for subbatch in batches]
return payloads

@classmethod
@inject
def from_batch_payloads( # pylint: disable=arguments-differ
cls,
payloads: t.Sequence[Payload],
batch_dim: int = 0,
plasma_db: ext.PlasmaClient | None = Provide[BentoMLContainer.plasma_db],
) -> tuple[ext.PdDataFrame, list[int]]:
batches = [cls.from_payload(payload, plasma_db) for payload in payloads]
batches = [cls.from_payload(payload) for payload in payloads]
return cls.batches_to_batch(batches, batch_dim)


Expand Down Expand Up @@ -487,20 +479,45 @@ def batch_to_batches(
def to_payload(cls, batch: t.Any, batch_dim: int) -> Payload:
if isinstance(batch, t.Generator): # Generators can't be pickled
batch = list(t.cast(t.Generator[t.Any, t.Any, t.Any], batch))

meta: dict[str, bool | int | float | str | list[int]] = {"format": "pickle5"}

bs: bytes
concat_buffer_bs: bytes
indices: list[int]
bs, concat_buffer_bs, indices = pep574_dumps(batch)

if indices:
meta["with_buffer"] = True
data = concat_buffer_bs
meta["pickle_bytes_str"] = base64.b64encode(bs).decode("ascii")
aarnphm marked this conversation as resolved.
Show resolved Hide resolved
meta["indices"] = indices
else:
meta["with_buffer"] = False
data = bs

if isinstance(batch, list):
return cls.create_payload(
pickle.dumps(batch), len(t.cast(t.List[t.Any], batch))
)
batch_size = len(t.cast(t.List[t.Any], batch))
else:
return cls.create_payload(pickle.dumps(batch), 1)
batch_size = 1

return cls.create_payload(
data=data,
batch_size=batch_size,
meta=meta,
)

@classmethod
@inject
def from_payload(cls, payload: Payload) -> t.Any:
return pickle.loads(payload.data)
if payload.meta["with_buffer"]:
bs_str = t.cast(str, payload.meta["pickle_bytes_str"])
bs = base64.b64decode(bs_str)
indices = t.cast(t.List[int], payload.meta["indices"])
return pep574_loads(bs, payload.data, indices)
else:
return pep574_loads(payload.data, b"", [])

@classmethod
@inject
def batch_to_payloads(
cls,
batch: list[t.Any],
Expand All @@ -513,7 +530,6 @@ def batch_to_payloads(
return payloads

@classmethod
@inject
def from_batch_payloads(
cls,
payloads: t.Sequence[Payload],
Expand Down
50 changes: 50 additions & 0 deletions src/bentoml/_internal/utils/pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

import sys
import typing as t

if sys.version_info < (3, 8):
import pickle5 as pickle
else:
import pickle


# Pickle protocol 5 with out-of-band data
# https://peps.python.org/pep-0574/
def pep574_dumps(obj: t.Any) -> tuple[bytes, bytes, list[int]]:
    """Pickle ``obj`` with protocol 5, extracting out-of-band buffers.

    Returns a 3-tuple of:

    * the main pickle byte stream,
    * every out-of-band buffer concatenated into a single ``bytes`` object,
    * the offsets delimiting each buffer inside the concatenation
      (``[0, end_0, end_1, ...]``), empty when no buffer was emitted.

    ``pep574_loads`` is the exact inverse.
    """
    import itertools

    buffers: list[pickle.PickleBuffer] = []
    main_bytes: bytes = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append)

    # Objects with no picklable out-of-band data (plain Python objects)
    # produce no buffers; signal that with an empty payload and index list.
    if not buffers:
        return main_bytes, b"", []

    buffer_bytes: list[bytes] = [buff.raw().tobytes() for buff in buffers]

    # Release the underlying memory as soon as possible; ``tobytes`` above
    # already copied the data.
    for buff in buffers:
        buff.release()

    # Offsets are 0 followed by the running totals of the buffer lengths,
    # so ``concat[indices[i]:indices[i + 1]]`` recovers buffer ``i``.
    indices: list[int] = [0, *itertools.accumulate(len(bs) for bs in buffer_bytes)]

    concat_buffer_bytes: bytes = b"".join(buffer_bytes)
    return main_bytes, concat_buffer_bytes, indices


def pep574_loads(
    main_bytes: bytes, concat_buffer_bytes: bytes, indices: list[int]
) -> t.Any:
    """Inverse of ``pep574_dumps``: unpickle using out-of-band buffers.

    ``indices`` delimits each original buffer inside ``concat_buffer_bytes``;
    an empty list means the pickle stream carries no out-of-band data.
    """
    if not indices:
        return pickle.loads(main_bytes)

    # Zero-copy views into the concatenated payload, one per original buffer.
    view = memoryview(concat_buffer_bytes)
    recovered = [
        pickle.PickleBuffer(view[start:stop])
        for start, stop in zip(indices, indices[1:])
    ]
    return pickle.loads(main_bytes, buffers=recovered)
59 changes: 59 additions & 0 deletions tests/unit/_internal/utils/test_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
import bentoml._internal.external_typing as ext


def test_pep574_restore() -> None:
    """Round-trip numpy arrays and pandas DataFrames through the PEP 574
    out-of-band pickling helpers, both as list and as dict containers.
    """
    import numpy as np
    import pandas as pd

    from bentoml._internal.utils.pickle import pep574_dumps
    from bentoml._internal.utils.pickle import pep574_loads

    def restore(obj: t.Any) -> t.Any:
        # One full dump/load cycle through the out-of-band helpers.
        bs, concat_buffer_bs, indices = pep574_dumps(obj)
        return pep574_loads(bs, concat_buffer_bs, indices)

    arr1: ext.NpNDArray = np.random.uniform(size=(20, 20))
    arr2: ext.NpNDArray = np.random.uniform(size=(64, 64))
    arr3: ext.NpNDArray = np.random.uniform(size=(72, 72))

    # numpy arrays in a list
    lst = [arr1, arr2, arr3]
    restored_lst = t.cast(t.List["ext.NpNDArray"], restore(lst))
    for arr, restored in zip(lst, restored_lst):
        assert np.allclose(arr, restored)

    # numpy arrays in a dict
    dic: dict[str, ext.NpNDArray] = dict(a=arr1, b=arr2, c=arr3)
    restored_dic = t.cast(t.Dict[str, "ext.NpNDArray"], restore(dic))
    for key, arr in dic.items():
        assert np.allclose(arr, restored_dic[key])

    df1: ext.PdDataFrame = pd.DataFrame(arr1)
    df2: ext.PdDataFrame = pd.DataFrame(arr2)
    df3: ext.PdDataFrame = pd.DataFrame(arr3)

    # DataFrames in a list
    df_lst = [df1, df2, df3]
    restored_df_lst = t.cast(t.List["ext.PdDataFrame"], restore(df_lst))
    for df, restored in zip(df_lst, restored_df_lst):
        assert np.allclose(df.to_numpy(), restored.to_numpy())

    # DataFrames in a dict
    df_dic: dict[str, ext.PdDataFrame] = dict(a=df1, b=df2, c=df3)
    restored_df_dic = t.cast(t.Dict[str, "ext.PdDataFrame"], restore(df_dic))
    for key, df in df_dic.items():
        assert np.allclose(df.to_numpy(), restored_df_dic[key].to_numpy())