From 2842d34e5683a184cfc2669428dce88366bb72f6 Mon Sep 17 00:00:00 2001 From: Joe Meyer Date: Thu, 9 May 2024 12:49:49 -0500 Subject: [PATCH] pass model name to img2img --- stablediffusion_fastapi_multigpu/app.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/stablediffusion_fastapi_multigpu/app.py b/stablediffusion_fastapi_multigpu/app.py index bdcadab..6a12869 100644 --- a/stablediffusion_fastapi_multigpu/app.py +++ b/stablediffusion_fastapi_multigpu/app.py @@ -33,21 +33,9 @@ async def log_duration(request: Request, call_next): # Ensure model is on GPU if available device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_name = os.environ.get("MODEL_NAME", "stabilityai/sdxl-turbo") -pipe_txt2img = StableDiffusionPipeline.from_pretrained( - "stabilityai/sdxl-turbo", torch_dtype=torch.float16 -) -pipe_img2img = StableDiffusionImg2ImgPipeline( - vae=pipe_txt2img.vae, - text_encoder=pipe_txt2img.text_encoder, - tokenizer=pipe_txt2img.tokenizer, - unet=pipe_txt2img.unet, - scheduler=pipe_txt2img.scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, -) + # Load model and assign it to available GPUs -num_gpus = 1 # torch.cuda.device_count() +num_gpus = torch.cuda.device_count() txt2img_pipes = [ StableDiffusionPipeline.from_pretrained( model_name, torch_dtype=torch.float16, variant="fp16" @@ -56,6 +44,7 @@ async def log_duration(request: Request, call_next): ] img2img_pipes = [ StableDiffusionImg2ImgPipeline.from_pretrained( + model_name, vae=txt2img_pipes[i].vae, text_encoder=txt2img_pipes[i].text_encoder, tokenizer=txt2img_pipes[i].tokenizer,