A method for reducing VRAM usage (for reference) #45

Open
peizhiluo007 opened this issue Jun 15, 2024 · 12 comments · May be fixed by #46

Comments

@peizhiluo007

peizhiluo007 commented Jun 15, 2024

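In the following function (ComfyUI-IDM-VTON\src\nodes\pipeline_loader.py):
def load_pipeline(self, weight_dtype):

Make two changes:
Change 1> Comment out every .to(DEVICE) call, all of them.

[image]

Change 2> At the end of the function
Before:
pipe.unet_encoder = unet_encoder
pipe = pipe.to(DEVICE)
pipe.weight_dtype = weight_dtype
After:
[image: screenshot of the modified code]

Tested on a 12 GB GPU with no trouble at all. VRAM usage was a little over 6 GB, so it should probably run on 8 GB as well.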
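
For anyone who wants to check the VRAM figure on their own card, here is a minimal sketch that is not part of the original post. It assumes PyTorch is installed; the helper names `vram_report` and `current_vram_report` are made up for illustration. It reports only what PyTorch allocated in the current process, so the number shown by a system tool may be higher.

```python
GIB = 1024 ** 3


def vram_report(allocated: int, reserved: int, total: int) -> str:
    """Format byte counts as GiB (hypothetical helper, pure Python)."""
    return (
        f"peak allocated {allocated / GIB:.2f} GiB, "
        f"peak reserved {reserved / GIB:.2f} GiB, "
        f"total {total / GIB:.2f} GiB"
    )


def current_vram_report(device_index: int = 0):
    """Return a report for one CUDA device, or None if torch/CUDA is missing."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    # Peak values since process start (or since the last reset of the stats).
    allocated = torch.cuda.max_memory_allocated(device_index)
    reserved = torch.cuda.max_memory_reserved(device_index)
    total = torch.cuda.get_device_properties(device_index).total_memory
    return vram_report(allocated, reserved, total)


if __name__ == "__main__":
    print(current_vram_report())
```

Calling `current_vram_report()` right after a generation run gives the peak for that process, which is one way to compare the pipeline before and after the change.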

@peizhiluo007
Author

peizhiluo007 commented Jun 15, 2024

def load_pipeline(self, weight_dtype):
    if weight_dtype == "float32":
        weight_dtype = torch.float32
    elif weight_dtype == "float16":
        weight_dtype = torch.float16
    elif weight_dtype == "bfloat16":
        weight_dtype = torch.bfloat16
    noise_scheduler = DDPMScheduler.from_pretrained(
        WEIGHTS_PATH, 
        subfolder="scheduler"
    )
    vae = AutoencoderKL.from_pretrained(
        WEIGHTS_PATH,
        subfolder="vae",
        torch_dtype=weight_dtype
    ).requires_grad_(False).eval()#.to(DEVICE)
    unet = UNet2DConditionModel.from_pretrained(
        WEIGHTS_PATH,
        subfolder="unet",
        torch_dtype=weight_dtype
    ).requires_grad_(False).eval()#.to(DEVICE)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        WEIGHTS_PATH,
        subfolder="image_encoder",
        torch_dtype=weight_dtype
    ).requires_grad_(False).eval()#.to(DEVICE)
    unet_encoder = UNet2DConditionModel_ref.from_pretrained(
        WEIGHTS_PATH,
        subfolder="unet_encoder",
        torch_dtype=weight_dtype
    ).requires_grad_(False).eval()#.to(DEVICE)
    text_encoder_one = CLIPTextModel.from_pretrained(
        WEIGHTS_PATH,
        subfolder="text_encoder",
        torch_dtype=weight_dtype
    ).requires_grad_(False).eval()#.to(DEVICE)
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
        WEIGHTS_PATH,
        subfolder="text_encoder_2",
        torch_dtype=weight_dtype
    ).requires_grad_(False).eval()#.to(DEVICE)
    tokenizer_one = AutoTokenizer.from_pretrained(
        WEIGHTS_PATH,
        subfolder="tokenizer",
        revision=None,
        use_fast=False,
    )
    tokenizer_two = AutoTokenizer.from_pretrained(
        WEIGHTS_PATH,
        subfolder="tokenizer_2",
        revision=None,
        use_fast=False,
    )
    pipe = TryonPipeline.from_pretrained(
        WEIGHTS_PATH,
        unet=unet,
        vae=vae,
        feature_extractor=CLIPImageProcessor(),
        text_encoder=text_encoder_one,
        text_encoder_2=text_encoder_two,
        tokenizer=tokenizer_one,
        tokenizer_2=tokenizer_two,
        scheduler=noise_scheduler,
        image_encoder=image_encoder,
        torch_dtype=weight_dtype,
    )
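    pipe.weight_dtype = weight_dtype
    # unet_encoder is attached after construction rather than passed to from_pretrained.
    pipe.unet_encoder = unet_encoder
    # Offload submodules to the CPU and bring each onto the GPU only while it runs.
    pipe.enable_sequential_cpu_offload()
    # Move unet_encoder to the GPU explicitly.
    pipe.unet_encoder.to(DEVICE)
    # pipe.to(DEVICE)  # disabled: this moved the whole pipeline onto the GPU
    return (pipe, )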

@TemryL
Owner

TemryL commented Jun 15, 2024

Wow that's awesome! Thanks! Could you open a PR with these changes?

@lldacing

Will this make it slower?

@TemryL TemryL linked a pull request Jun 18, 2024 that will close this issue

@qtmssa

qtmssa commented Jun 20, 2024

Does this work?

@qtmssa

qtmssa commented Jun 20, 2024

My 4090 (with 24 GB of VRAM) still runs out of memory. :-(
Can anybody help? :-)

@deepfree2023

Works without problem.

@peizhiluo007
Author

peizhiluo007 commented Jun 22, 2024

> Wow that's awesome! Thanks! Could you open a PR with these changes?

OK, I have submitted it for review.
I also think it would be better to add a lowvram option.
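
A lowvram option like the one suggested here could look roughly like the sketch below. This is only an illustration, not the code from the PR: the function name `place_pipeline` and the `lowvram` flag are hypothetical, and `pipe` is assumed to be a diffusers-style pipeline that has `enable_sequential_cpu_offload()` and `to()`, plus the `unet_encoder` attribute set in `load_pipeline`.

```python
def place_pipeline(pipe, device, lowvram=True):
    """Choose where the pipeline lives (hypothetical helper, not from the PR).

    lowvram=True  : sequential CPU offload, as in the modified load_pipeline.
    lowvram=False : move the whole pipeline to the device, as in the original.
    """
    if lowvram:
        # Submodules stay on the CPU and are moved to the GPU only while running.
        pipe.enable_sequential_cpu_offload()
    else:
        # Original behaviour: everything resident on the GPU.
        pipe.to(device)
    # unet_encoder is attached to the pipeline by hand in load_pipeline,
    # so it is moved explicitly in both modes.
    pipe.unet_encoder.to(device)
    return pipe
```

With something like this, the end of `load_pipeline` would reduce to one call such as `place_pipeline(pipe, DEVICE, lowvram)`, and users with enough VRAM could keep the faster fully resident mode.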

@dachangqing

I changed the code as you described, so why do I still get an out-of-memory error?

@925-Studio

Thanks for sharing this tip, it works fine.

@TemryL
Owner

TemryL commented Jul 2, 2024

Wow awesome, thank you so much for this finding! Could you create a PR for this?

@Jeff-goal

After upgrading, this method throws an error and vton can no longer be imported. How should I change it?

@zhucenichenghao

It works, buddy.
