Skip to content

Commit

Permalink
pass model name to img2img
Browse files Browse the repository at this point in the history
  • Loading branch information
JEMeyer committed May 9, 2024
1 parent 4242a41 commit 2842d34
Showing 1 changed file with 3 additions and 14 deletions.
17 changes: 3 additions & 14 deletions stablediffusion_fastapi_multigpu/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,9 @@ async def log_duration(request: Request, call_next):
# Ensure model is on GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = os.environ.get("MODEL_NAME", "stabilityai/sdxl-turbo")
pipe_txt2img = StableDiffusionPipeline.from_pretrained(
"stabilityai/sdxl-turbo", torch_dtype=torch.float16
)
pipe_img2img = StableDiffusionImg2ImgPipeline(
vae=pipe_txt2img.vae,
text_encoder=pipe_txt2img.text_encoder,
tokenizer=pipe_txt2img.tokenizer,
unet=pipe_txt2img.unet,
scheduler=pipe_txt2img.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)

# Load model and assign it to available GPUs
num_gpus = 1 # torch.cuda.device_count()
num_gpus = torch.cuda.device_count()
txt2img_pipes = [
StableDiffusionPipeline.from_pretrained(
model_name, torch_dtype=torch.float16, variant="fp16"
Expand All @@ -56,6 +44,7 @@ async def log_duration(request: Request, call_next):
]
img2img_pipes = [
StableDiffusionImg2ImgPipeline.from_pretrained(
model_name,
vae=txt2img_pipes[i].vae,
text_encoder=txt2img_pipes[i].text_encoder,
tokenizer=txt2img_pipes[i].tokenizer,
Expand Down

0 comments on commit 2842d34

Please sign in to comment.