Commit bceccca

Small refactor.
comfyanonymous committed Apr 7, 2023
1 parent 28a7205 commit bceccca
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions comfy/model_management.py
@@ -129,7 +129,6 @@ def load_model_gpu(model):
     global current_loaded_model
     global vram_state
     global model_accelerated
-    global xpu_available
 
     if model is current_loaded_model:
         return
@@ -148,17 +147,14 @@ def load_model_gpu(model):
         pass
     elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
-        if xpu_available:
-            real_model.to("xpu")
-        else:
-            real_model.cuda()
+        real_model.to(get_torch_device())
     else:
         if vram_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
         model_accelerated = True
     return current_loaded_model
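Aside: the NO_VRAM/LOW_VRAM branch above leans on two accelerate calls. infer_auto_device_map() plans which submodules fit within a per-device memory budget, and dispatch_model() installs hooks that route weights and activations to the right device at call time. A minimal self-contained sketch of that pattern, with a toy model and illustrative budget values that are not from the repository (assumes a CUDA GPU is present):

import accelerate
import torch
import torch.nn as nn

# Toy model, roughly 1 GiB of fp32 weights: large enough that a tight
# GPU budget forces some layers onto the CPU.
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(16)])

# Plan placement: GPU 0 may hold up to 256 MiB of weights, the rest stays on CPU.
device_map = accelerate.infer_auto_device_map(
    model, max_memory={0: "256MiB", "cpu": "16GiB"})

# Install the hooks that move tensors between devices per submodule.
model = accelerate.dispatch_model(model, device_map=device_map, main_device="cuda")

out = model(torch.randn(1, 4096))  # runs transparently across the split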

@@ -184,12 +180,8 @@ def load_controlnet_gpu(models):

 def load_if_low_vram(model):
     global vram_state
-    global xpu_available
     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
-        if xpu_available:
-            return model.to("xpu")
-        else:
-            return model.cuda()
+        return model.to(get_torch_device())
     return model
 
 def unload_if_low_vram(model):
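Every call site now defers to get_torch_device() instead of branching on the backend inline. That helper is defined elsewhere in comfy/model_management.py and is not part of this diff; the following is a rough reconstruction inferred from the branches deleted above, a sketch under those assumptions rather than the exact function:

import torch
from enum import Enum

class VRAMState(Enum):  # simplified stand-in for the enum used in the file
    CPU = 0
    NORMAL_VRAM = 1

vram_state = VRAMState.NORMAL_VRAM  # assumed module-level state, as in the diff
# assumed detection flag; intel-extension-for-pytorch registers torch.xpu
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

def get_torch_device():
    # One place that decides cpu vs. Intel XPU vs. CUDA, mirroring the
    # per-call-site checks this commit deletes.
    if vram_state == VRAMState.CPU:
        return torch.device("cpu")
    if xpu_available:
        return torch.device("xpu")
    return torch.device("cuda")

With that single chokepoint, each caller shrinks to model.to(get_torch_device()), which is exactly the shape of the three hunks above.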
