Xl python inference (#261)

* Updated pipeline.py for XL inference

* cleaned up

* Add shape handling for UNET time_ids shape

* added support for loading from CompiledMLModel
lopez-hector authored Sep 26, 2023
1 parent 94dfc6b commit f3a2124
Showing 2 changed files with 415 additions and 171 deletions.
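The headline change: CoreMLModel can now wrap either an .mlpackage (which Core ML compiles on load) or a precompiled .mlmodelc bundle. A minimal usage sketch of the two load paths, with hypothetical local paths and an illustrative SDXL model version:

    from python_coreml_stable_diffusion.coreml_model import CoreMLModel

    # Pre-existing path: load an .mlpackage; Core ML compiles it on load.
    unet = CoreMLModel(
        "Resources/Stable_Diffusion_version_stabilityai_stable-diffusion-xl-base-1.0_unet.mlpackage",
        compute_unit="CPU_AND_GPU",
        sources="packages",
    )

    # New in this commit: load a precompiled .mlmodelc and skip compile-on-load.
    unet_compiled = CoreMLModel(
        "Resources/Unet.mlmodelc",
        compute_unit="CPU_AND_GPU",
        sources="compiled",
    )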
138 changes: 107 additions & 31 deletions python_coreml_stable_diffusion/coreml_model.py
@@ -6,6 +6,7 @@
import coremltools as ct

import logging
import json

logging.basicConfig()
logger = logging.getLogger(__name__)
@@ -21,14 +22,47 @@ class CoreMLModel:
""" Wrapper for running CoreML models using coremltools
"""

def __init__(self, model_path, compute_unit):
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
def __init__(self, model_path, compute_unit, sources='packages'):

logger.info(f"Loading {model_path}")

start = time.time()
self.model = ct.models.MLModel(
model_path, compute_units=ct.ComputeUnit[compute_unit])
if sources == 'packages':
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")

self.model = ct.models.MLModel(
model_path, compute_units=ct.ComputeUnit[compute_unit])
DTYPE_MAP = {
65552: np.float16,
65568: np.float32,
131104: np.int32,
}
self.expected_inputs = {
input_tensor.name: {
"shape": tuple(input_tensor.type.multiArrayType.shape),
"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
}
for input_tensor in self.model._spec.description.input
}
elif sources == 'compiled':
assert os.path.exists(model_path) and model_path.endswith(".mlmodelc")

self.model = ct.models.CompiledMLModel(model_path, ct.ComputeUnit[compute_unit])

# Grab expected inputs from metadata.json
with open(os.path.join(model_path, 'metadata.json'), 'r') as f:
config = json.load(f)[0]

self.expected_inputs = {
input_tensor['name']: {
"shape": tuple(eval(input_tensor['shape'])),
"dtype": np.dtype(input_tensor['dataType'].lower()),
}
for input_tensor in config['inputSchema']
}
else:
raise ValueError(f'Expected `packages` or `compiled` for sources, received {sources}')

load_time = time.time() - start
logger.info(f"Done. Took {load_time:.1f} seconds.")

@@ -38,21 +72,6 @@ def __init__(self, model_path, compute_unit):
"The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
)


DTYPE_MAP = {
65552: np.float16,
65568: np.float32,
131104: np.int32,
}

self.expected_inputs = {
input_tensor.name: {
"shape": tuple(input_tensor.type.multiArrayType.shape),
"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
}
for input_tensor in self.model._spec.description.input
}

def _verify_inputs(self, **kwargs):
for k, v in kwargs.items():
if k in self.expected_inputs:
@@ -72,7 +91,7 @@ def _verify_inputs(self, **kwargs):
f"Expected shape {expected_shape}, got {v.shape} for input: {k}"
)
else:
raise ValueError("Received unexpected input kwarg: {k}")
raise ValueError(f"Received unexpected input kwarg: {k}")

def __call__(self, **kwargs):
self._verify_inputs(**kwargs)
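The one-character fix in _verify_inputs above makes the error name the offending kwarg instead of printing a literal "{k}":

    k = "tim_ids"  # a misspelled input name, for illustration
    print("Received unexpected input kwarg: {k}")   # before: Received unexpected input kwarg: {k}
    print(f"Received unexpected input kwarg: {k}")  # after:  Received unexpected input kwarg: tim_ids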
@@ -82,21 +101,77 @@ def __call__(self, **kwargs):
LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds


def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
compute_unit):
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
def get_resource_type(resources_dir: str) -> str:
"""
Detect resource type based on filepath extensions.
Returns:
    `packages`: for .mlpackage resources
    `compiled`: for .mlmodelc resources
"""
logger.info(f"Loading {submodule_name} mlpackage")
directories = [f for f in os.listdir(resources_dir) if os.path.isdir(os.path.join(resources_dir, f))]

fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
"/", "_")
mlpackage_path = os.path.join(mlpackages_dir, fname)
# consider only directories that end with a file extension
extensions = {os.path.splitext(e)[1] for e in directories if os.path.splitext(e)[1]}

if not os.path.exists(mlpackage_path):
raise FileNotFoundError(
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
# if exactly one extension is present we can infer the sources type
if len(extensions) == 1:
    extension = extensions.pop()
else:
    raise ValueError(f'Multiple file extensions found at {resources_dir}. '
                     'Cannot infer resource type from contents.')

if extension == '.mlpackage':
sources = 'packages'
elif extension == '.mlmodelc':
sources = 'compiled'
else:
raise ValueError(f'Did not find .mlpackage or .mlmodelc at {resources_dir}')

return sources
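A hypothetical check of the new helper (directory name and contents are illustrative):

    from python_coreml_stable_diffusion.coreml_model import get_resource_type

    # e.g. a directory holding only Unet.mlmodelc, TextEncoder.mlmodelc, ...
    sources = get_resource_type("Resources")
    assert sources in ("packages", "compiled")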


def _load_mlpackage(submodule_name,
mlpackages_dir,
model_version,
compute_unit,
sources=None):
"""
Load Core ML models from disk (as exported by torch2coreml.py), either as .mlpackage bundles or compiled .mlmodelc resources.
"""

# if sources not provided, attempt to infer `packages` or `compiled` from the
# resources directory
if sources is None:
sources = get_resource_type(mlpackages_dir)

if sources == 'packages':
logger.info(f"Loading {submodule_name} mlpackage")
fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
"/", "_")
mlpackage_path = os.path.join(mlpackages_dir, fname)

if not os.path.exists(mlpackage_path):
raise FileNotFoundError(
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")

elif sources == 'compiled':
logger.info(f"Loading {submodule_name} mlmodelc")

# FIXME: submodule names and compiled resource names differ; drop this mapping if the names are unified in the future.
submodule_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder"]
compiled_names = ['TextEncoder', 'TextEncoder2', 'Unet', 'VAEDecoder', 'VAEEncoder']
name_map = dict(zip(submodule_names, compiled_names))

cname = name_map[submodule_name] + '.mlmodelc'
mlpackage_path = os.path.join(mlpackages_dir, cname)

if not os.path.exists(mlpackage_path):
raise FileNotFoundError(
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")

return CoreMLModel(mlpackage_path, compute_unit, sources=sources)

return CoreMLModel(mlpackage_path, compute_unit)
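A hypothetical call into the updated loader; model_version and the resources directory are illustrative, and passing sources=None falls back to get_resource_type:

    unet = _load_mlpackage(
        "unet",
        mlpackages_dir="Resources",
        model_version="stabilityai/stable-diffusion-xl-base-1.0",
        compute_unit="CPU_AND_NE",
        sources="compiled",  # maps "unet" -> Resources/Unet.mlmodelc
    )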

def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
@@ -115,5 +190,6 @@ def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):

return CoreMLModel(mlpackage_path, compute_unit)


def get_available_compute_units():
return tuple(cu for cu in ct.ComputeUnit._member_names_)
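The names returned here mirror coremltools' ComputeUnit enum, and the compute_unit string passed to CoreMLModel must be one of them:

    import coremltools as ct

    print(get_available_compute_units())  # e.g. ('ALL', 'CPU_AND_GPU', 'CPU_ONLY', 'CPU_AND_NE')
    assert "CPU_AND_NE" in ct.ComputeUnit._member_names_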