Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xl python inference #261

Merged
merged 4 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 107 additions & 31 deletions python_coreml_stable_diffusion/coreml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import coremltools as ct

import logging
import json

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand All @@ -21,14 +22,47 @@ class CoreMLModel:
""" Wrapper for running CoreML models using coremltools
"""

def __init__(self, model_path, compute_unit):
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
def __init__(self, model_path, compute_unit, sources='packages'):

logger.info(f"Loading {model_path}")

start = time.time()
self.model = ct.models.MLModel(
model_path, compute_units=ct.ComputeUnit[compute_unit])
if sources == 'packages':
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")

self.model = ct.models.MLModel(
model_path, compute_units=ct.ComputeUnit[compute_unit])
DTYPE_MAP = {
65552: np.float16,
65568: np.float32,
131104: np.int32,
}
self.expected_inputs = {
input_tensor.name: {
"shape": tuple(input_tensor.type.multiArrayType.shape),
"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
}
for input_tensor in self.model._spec.description.input
}
elif sources == 'compiled':
assert os.path.exists(model_path) and model_path.endswith(".mlmodelc")

self.model = ct.models.CompiledMLModel(model_path, ct.ComputeUnit[compute_unit])

# Grab expected inputs from metadata.json
with open(os.path.join(model_path, 'metadata.json'), 'r') as f:
config = json.load(f)[0]

self.expected_inputs = {
input_tensor['name']: {
"shape": tuple(eval(input_tensor['shape'])),
"dtype": np.dtype(input_tensor['dataType'].lower()),
}
for input_tensor in config['inputSchema']
}
else:
raise ValueError(f'Expected `packages` or `compiled` for sources, received {sources}')

load_time = time.time() - start
logger.info(f"Done. Took {load_time:.1f} seconds.")

Expand All @@ -38,21 +72,6 @@ def __init__(self, model_path, compute_unit):
"The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
)


DTYPE_MAP = {
65552: np.float16,
65568: np.float32,
131104: np.int32,
}

self.expected_inputs = {
input_tensor.name: {
"shape": tuple(input_tensor.type.multiArrayType.shape),
"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
}
for input_tensor in self.model._spec.description.input
}

def _verify_inputs(self, **kwargs):
for k, v in kwargs.items():
if k in self.expected_inputs:
Expand All @@ -72,7 +91,7 @@ def _verify_inputs(self, **kwargs):
f"Expected shape {expected_shape}, got {v.shape} for input: {k}"
)
else:
raise ValueError("Received unexpected input kwarg: {k}")
raise ValueError(f"Received unexpected input kwarg: {k}")

def __call__(self, **kwargs):
self._verify_inputs(**kwargs)
Expand All @@ -82,21 +101,77 @@ def __call__(self, **kwargs):
LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds


def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
compute_unit):
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
def get_resource_type(resources_dir: str) -> str:
    """
    Detect resource type based on filepath extensions.

    Args:
        resources_dir: Directory expected to contain either `.mlpackage`
            or `.mlmodelc` sub-directories.

    Returns:
        `packages` for .mlpackage resources,
        `compiled` for .mlmodelc resources.

    Raises:
        ValueError: If zero or multiple distinct extensions are found, or the
            single extension is neither `.mlpackage` nor `.mlmodelc`.
    """
    directories = [
        f for f in os.listdir(resources_dir)
        if os.path.isdir(os.path.join(resources_dir, f))
    ]

    # Core ML resources are directories whose names carry the extension;
    # ignore extension-less directories.
    extensions = {
        os.path.splitext(entry)[1]
        for entry in directories
        if os.path.splitext(entry)[1]
    }

    # Only a single, unambiguous extension allows inferring the sources type
    if len(extensions) != 1:
        raise ValueError(
            f'Expected exactly one extension (.mlpackage or .mlmodelc) at '
            f'{resources_dir}, found {sorted(extensions)}. '
            'Cannot infer resource type from contents.')

    extension = extensions.pop()
    if extension == '.mlpackage':
        return 'packages'
    elif extension == '.mlmodelc':
        return 'compiled'

    raise ValueError(f'Did not find .mlpackage or .mlmodelc at {resources_dir}')


def _load_mlpackage(submodule_name,
mlpackages_dir,
model_version,
compute_unit,
sources=None):
"""
Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)

"""

# if sources not provided, attempt to infer `packages` or `compiled` from the
# resources directory
if sources is None:
sources = get_resource_type(mlpackages_dir)

if sources == 'packages':
logger.info(f"Loading {submodule_name} mlpackage")
fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
"/", "_")
mlpackage_path = os.path.join(mlpackages_dir, fname)

if not os.path.exists(mlpackage_path):
raise FileNotFoundError(
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")

elif sources == 'compiled':
logger.info(f"Loading {submodule_name} mlmodelc")

# FixMe: Submodule names and compiled resources names differ. Can change if names match in the future.
submodule_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder"]
compiled_names = ['TextEncoder', 'TextEncoder2', 'Unet', 'VAEDecoder', 'VAEEncoder']
name_map = dict(zip(submodule_names, compiled_names))

cname = name_map[submodule_name] + '.mlmodelc'
mlpackage_path = os.path.join(mlpackages_dir, cname)

if not os.path.exists(mlpackage_path):
raise FileNotFoundError(
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")

return CoreMLModel(mlpackage_path, compute_unit, sources=sources)

return CoreMLModel(mlpackage_path, compute_unit)

def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
Expand All @@ -115,5 +190,6 @@ def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):

return CoreMLModel(mlpackage_path, compute_unit)


def get_available_compute_units():
    """Return the names of all Core ML compute units exposed by coremltools."""
    return tuple(ct.ComputeUnit._member_names_)
Loading