Skip to content

Commit

Permalink
Improve ZLUDA installation.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Apr 29, 2024
1 parent 6130ef9 commit 620b78c
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 36 deletions.
1 change: 1 addition & 0 deletions modules/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
parser.add_argument("--use-cpu-torch", action="store_true", help="use torch built with cpu")
parser.add_argument("--use-directml", action="store_true", help="use DirectML device as torch device")
parser.add_argument("--use-zluda", action="store_true", help="use ZLUDA device as torch device")
parser.add_argument("--use-zluda-dnn", action="store_true", help="enable ZLUDA DNN")
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
Expand Down
4 changes: 2 additions & 2 deletions modules/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from modules import images
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, process_images
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
from modules.sd_models import get_closet_checkpoint_match
import modules.shared as shared
Expand Down Expand Up @@ -186,7 +186,7 @@ def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_

assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

p = processing.StableDiffusionProcessingImg2Img(
p = StableDiffusionProcessingImg2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
Expand Down
59 changes: 27 additions & 32 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,13 @@ def prepare_environment():
rocm_found = False
hip_found = False
backend = "cuda"
torch_command = "pip install torch==2.2.2 torchvision" if args.use_cpu_torch else "pip install torch==2.2.2 torchvision --extra-index-url https://download.pytorch.org/whl/cu121"
torch_command = "pip install torch==2.3.0 torchvision" if args.use_cpu_torch else "pip install torch==2.2.2 torchvision --extra-index-url https://download.pytorch.org/whl/cu121"

if args.use_cpu_torch:
backend = "cpu"
torch_command = os.environ.get(
"TORCH_COMMAND",
"pip install torch==2.2.2 torchvision",
"pip install torch==2.3.0 torchvision",
)
elif args.use_directml:
backend = "directml"
Expand All @@ -460,32 +460,23 @@ def prepare_environment():
)
torch_command = os.environ.get(
"TORCH_COMMAND",
f"pip install torch==2.2.2 torchvision --index-url {torch_index_url}",
f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}",
)
zluda_path = find_zluda()
if zluda_path is None:
is_windows = system == "Windows"
import urllib.request
import zipfile
import tarfile
archive_type = zipfile.ZipFile if is_windows else tarfile.TarFile
try:
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.9e97c717c3fef536d3116f39a15d95626c1dfe39/ZLUDA-{platform.system().lower()}-amd64.{"zip" if is_windows else "tar.gz"}', '_zluda')
with archive_type('_zluda', 'r') as f:
f.extractall('.zluda')
zluda_path = os.path.abspath('./.zluda')
os.remove('_zluda')
except Exception as e:
print(f'Failed to install ZLUDA: {e}')
if os.path.exists(os.path.join(zluda_path, 'nvcuda.dll')):
print(f'Using ZLUDA in {zluda_path}')
torch_command = os.environ.get(
'TORCH_COMMAND',
'pip install torch==2.2.2 torchvision --index-url https://download.pytorch.org/whl/cu118',
)
paths = os.environ.get('PATH', '.')
if zluda_path not in paths:
os.environ['PATH'] = paths + ';' + zluda_path
try:
from modules import zluda_installer
if args.use_zluda_dnn:
if zluda_installer.check_dnn_dependency():
zluda_installer.enable_dnn()
else:
print("Couldn't find the required dependency of ZLUDA DNN.")
zluda_installer.install()
zluda_installer.resolve_path()
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
print(f'Using ZLUDA in {zluda_installer.ZLUDA_PATH}')
except Exception as e:
print(f'Failed to install ZLUDA: {e}')
print('Using CPU-only torch')
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch torchvision')
elif args.use_ipex:
backend = "ipex"
if system == "Windows":
Expand Down Expand Up @@ -519,7 +510,7 @@ def prepare_environment():
)
torch_command = os.environ.get(
"TORCH_COMMAND",
f"pip install torch==2.2.0 torchvision==0.17.0 --extra-index-url {torch_index_url}",
f"pip install torch==2.3.0 torchvision --extra-index-url {torch_index_url}",
)
elif system == "Windows" and hip_found: # ZLUDA
print("ROCm Toolkit was found.")
Expand All @@ -529,17 +520,17 @@ def prepare_environment():
)
torch_command = os.environ.get(
"TORCH_COMMAND",
f"pip install torch==2.2.1 torchvision --index-url {torch_index_url}",
f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}",
)
elif rocm_found:
print("ROCm Toolkit was found.")
backend = "rocm"
torch_index_url = os.environ.get(
"TORCH_INDEX_URL", "https://download.pytorch.org/whl/rocm5.4.2"
"TORCH_INDEX_URL", "https://download.pytorch.org/whl/rocm6.0"
)
torch_command = os.environ.get(
"TORCH_COMMAND",
f"pip install torch==2.0.1 torchvision==0.15.2 --index-url {torch_index_url}",
f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}",
)

requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
Expand Down Expand Up @@ -585,7 +576,11 @@ def prepare_environment():
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
startup_timer.record("install torch")
if args.use_zluda:
patch_zluda()
try:
from modules.zluda_installer import patch as patch_torch
patch_torch()
except Exception as e:
print(f'ZLUDA: failed to automatically patch torch: {e}')

if args.use_ipex or args.use_directml or args.use_cpu_torch:
args.skip_torch_cuda_test = True
Expand Down
14 changes: 12 additions & 2 deletions modules/zluda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import platform
import torch
from torch._prims_common import DeviceLikeType
from modules import devices
from modules import shared, devices


conv2d = torch.nn.functional.conv2d
def conv2d_cudnn_disabled(*args, **kwargs):
torch.backends.cudnn.enabled = False
R = conv2d(*args, **kwargs)
torch.backends.cudnn.enabled = True
return R


def is_zluda(device: DeviceLikeType):
Expand All @@ -23,10 +31,12 @@ def test(device: DeviceLikeType):
def initialize_zluda():
device = devices.get_optimal_device()
if platform.system() == "Windows" and torch.cuda.is_available() and is_zluda(device):
torch.backends.cudnn.enabled = False
torch.backends.cudnn.enabled = shared.cmd_opts.use_zluda_dnn
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
if shared.cmd_opts.use_zluda_dnn:
torch.nn.functional.conv2d = conv2d_cudnn_disabled
devices.device_codeformer = devices.cpu

if not test(device):
Expand Down
94 changes: 94 additions & 0 deletions modules/zluda_installer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import os
import shutil
import zipfile
import tarfile
import platform
import urllib.request


RELEASE = 'rel.9e97c717c3fef536d3116f39a15d95626c1dfe39'
TARGETS = {
'cublas.dll': 'cublas64_11.dll',
'cusparse.dll': 'cusparse64_11.dll',
'nvrtc.dll': 'nvrtc64_112_0.dll',
}
ZLUDA_PATH = None
TORCHLIB_PATH = None


def find_zluda_path():
zluda_path = os.environ.get('ZLUDA', None)
if zluda_path is None:
paths = os.environ.get('PATH', '').split(';')
for path in paths:
if os.path.exists(os.path.join(path, 'zluda_redirect.dll')):
zluda_path = path
break
return zluda_path


def find_venv_dir():
python_dir = os.path.dirname(shutil.which('python'))
if shutil.which('conda') is None:
python_dir = os.path.dirname(python_dir)
return os.environ.get('VENV_DIR', python_dir)


def reset_torch():
for dll in TARGETS.values():
path = os.path.join(TORCHLIB_PATH, dll)
if os.path.exists(path):
os.remove(path)


def is_patched():
for dll in TARGETS.values():
if not os.path.islink(os.path.join(TORCHLIB_PATH, dll)):
return False
return True


def check_dnn_dependency():
hip_path = os.environ.get("HIP_PATH", None)
if hip_path is None: # unable to check
return True
if os.path.exists(os.path.join(hip_path, 'bin', 'MIOpen.dll')):
return True
return False


def enable_dnn():
global RELEASE # pylint: disable=global-statement
TARGETS['cudnn.dll'] = 'cudnn64_8.dll'
RELEASE = 'v3.8-pre2-dnn'


def install():
global ZLUDA_PATH, TORCHLIB_PATH # pylint: disable=global-statement
ZLUDA_PATH = find_zluda_path()
TORCHLIB_PATH = os.path.join(find_venv_dir(), 'Lib', 'site-packages', 'torch', 'lib')

if ZLUDA_PATH is not None:
return

is_windows = platform.system() == 'Windows'
archive_type = zipfile.ZipFile if is_windows else tarfile.TarFile
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-{platform.system().lower()}-amd64.{"zip" if is_windows else "tar.gz"}', '_zluda')
with archive_type('_zluda', 'r') as f:
f.extractall('.zluda')
ZLUDA_PATH = os.path.abspath('./.zluda')
os.remove('_zluda')


def resolve_path():
paths = os.environ.get('PATH', '.')
if ZLUDA_PATH not in paths:
os.environ['PATH'] = paths + ';' + ZLUDA_PATH


def patch():
if is_patched():
return
reset_torch()
for k, v in TARGETS.items():
os.symlink(os.path.join(ZLUDA_PATH, k), os.path.join(TORCHLIB_PATH, v))

0 comments on commit 620b78c

Please sign in to comment.