Commit

Merge branch 'main' into chore/figlet-cli
sauyon authored Feb 23, 2023
2 parents 22f229a + 18e6415 commit ff4c7b2
Showing 1 changed file with 20 additions and 3 deletions.
src/bentoml/_internal/frameworks/diffusers.py
@@ -53,6 +53,7 @@ class DiffusersOptions(PartialKwargsModelOptions):
    torch_dtype: str | torch.dtype | None = None
    custom_pipeline: str | None = None
    enable_xformers: bool | None = None
+   enable_attention_slicing: int | str | None = None


def get(tag_like: str | Tag) -> bentoml.Model:
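
For context, the new option value is handed straight through to diffusers' own pipeline.enable_attention_slicing(), which trades a little extra compute for a lower peak memory footprint. A minimal illustration of the underlying diffusers call (the Hub model id is only the usual example, not something this commit pins):

    from diffusers import StableDiffusionPipeline

    # Compute attention in slices rather than one large matmul to cut peak VRAM.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.enable_attention_slicing("auto")  # "auto", "max", or an int slice size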
@@ -83,6 +84,7 @@ def get(tag_like: str | Tag) -> bentoml.Model:

def load_model(
    bento_model: str | Tag | bentoml.Model,
+   device_id: str | torch.device | None = None,
    pipeline_class: type[
        diffusers.pipelines.DiffusionPipeline
    ] = diffusers.StableDiffusionPipeline,
@@ -92,6 +94,7 @@ def get(tag_like: str | Tag) -> bentoml.Model:
    torch_dtype: str | torch.dtype | None = None,
    low_cpu_mem_usage: bool | None = None,
    enable_xformers: bool = False,
+   enable_attention_slicing: int | str | None = None,
) -> diffusers.DiffusionPipeline:
"""
Load a Diffusion model and convert it to diffusers `Pipeline <https://huggingface.co/docs/diffusers/api/pipelines/overview>`_
@@ -101,6 +104,8 @@ def load_model(
        bento_model:
            Either the tag of the model to get from the store, or a BentoML
            ``~bentoml.Model`` instance to load the model from.
+       device_id (:code:`str`, `optional`, defaults to :code:`None`):
+           Optional device to put the given model on. Refer to `device attributes <https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device>`_.
        pipeline_class (:code:`type[diffusers.DiffusionPipeline]`, `optional`):
            DiffusionPipeline class used to load the saved diffusion model; defaults to
            ``diffusers.StableDiffusionPipeline``. For more pipeline types, refer to
@@ -173,9 +178,15 @@ def load_model(
    )
    pipeline.scheduler = scheduler

+   if device_id is not None:
+       pipeline = pipeline.to(device_id)
+
    if enable_xformers:
        pipeline.enable_xformers_memory_efficient_attention()

+   if enable_attention_slicing is not None:
+       pipeline.enable_attention_slicing(enable_attention_slicing)
+
    return pipeline
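
Taken together, the new parameters make device placement and memory tuning part of the load call itself. A minimal sketch of calling load_model with them (the tag "sd2:latest" is a hypothetical placeholder for a previously saved diffusion model):

    import bentoml
    import torch

    # Sketch only: place the pipeline on GPU when one is present and enable
    # attention slicing; both map onto the parameters added in this commit.
    pipeline = bentoml.diffusers.load_model(
        "sd2:latest",
        device_id="cuda" if torch.cuda.is_available() else None,
        enable_attention_slicing="auto",
    )
    image = pipeline(prompt="a watercolor lighthouse at dusk").images[0]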


@@ -403,6 +414,9 @@ def get_runnable(bento_model: bentoml.Model) -> t.Type[bentoml.Runnable]:
    ] | None = bento_model.info.options.scheduler_class
    custom_pipeline: str | None = bento_model.info.options.custom_pipeline
    _enable_xformers: str | None = bento_model.info.options.enable_xformers
+   enable_attention_slicing: int | str | None = (
+       bento_model.info.options.enable_attention_slicing
+   )
    _torch_dtype: str | torch.dtype | None = bento_model.info.options.torch_dtype

    class DiffusersRunnable(bentoml.Runnable):
@@ -422,18 +436,21 @@ def __init__(self):
            if is_xformers_available():
                enable_xformers: bool = True

+           device_id: str | None = None
+           if torch.cuda.is_available():
+               device_id = "cuda"
+
            self.pipeline: diffusers.DiffusionPipeline = load_model(
                bento_model,
+               device_id=device_id,
                pipeline_class=pipeline_class,
                scheduler_class=scheduler_class,
                torch_dtype=torch_dtype,
                custom_pipeline=custom_pipeline,
                enable_xformers=enable_xformers,
+               enable_attention_slicing=enable_attention_slicing,
            )

-           if torch.cuda.is_available():
-               self.pipeline.to("cuda")
-
    def make_run_method(
        method_name: str, partial_kwargs: dict[str, t.Any] | None
    ) -> t.Callable[..., t.Any]:
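
With this change the runnable resolves its device once in __init__ and passes it to load_model, instead of moving the pipeline afterwards with .to("cuda"). Usage is unchanged; a short sketch (the tag is again a hypothetical placeholder):

    import bentoml

    # The runnable above picks "cuda" automatically when a GPU is available,
    # so no explicit .to("cuda") is needed on the caller's side.
    runner = bentoml.diffusers.get("sd2:latest").to_runner()
    runner.init_local()  # local debugging; inside a Service, BentoML manages this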
