
Commit

PowerPaint (#2076)
* support powerpaint

* Update gradio_PowerPaint.py
zhuang2002 authored Dec 6, 2023
1 parent c1873dd commit 5974178
Showing 5 changed files with 3,737 additions and 0 deletions.
133 changes: 133 additions & 0 deletions projects/powerpaint/gradio_PowerPaint.py
@@ -0,0 +1,133 @@
import sys
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from pathlib import Path
torch.set_grad_enabled(False)
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from pipeline.pipeline_PowerPaint import StableDiffusionInpaintPipeline as Pipeline
from utils.utils import *
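# Build the PowerPaint pipeline: start from the SD inpainting weights, swap in
# the SD 1.5 tokenizer, text encoder and VAE, then load the fine-tuned
# PowerPaint checkpoints below.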
pipe = Pipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None)
pipe.tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer", revision=None, torch_dtype=torch.float16)
pipe.text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder", revision=None, torch_dtype=torch.float16)
pipe.vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", revision=None, torch_dtype=torch.float16)

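# Register the three PowerPaint task tokens (MMcontext, MMshape, MMobject) as
# multi-vector placeholder embeddings, then load the fine-tuned UNet and text
# encoder weights (the ./models paths assume the checkpoints have been
# downloaded there).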
pipe.tokenizer = TokenizerWrapper(
    from_pretrained="runwayml/stable-diffusion-v1-5", subfolder="tokenizer", revision=None)
add_tokens(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder,
           placeholder_tokens=["MMcontext", "MMshape", "MMobject"],
           initialize_tokens=["a", "a", "a"], num_vectors_per_token=10)
pipe.unet.load_state_dict(torch.load("./models/diffusion_pytorch_model.bin"), strict=False)
pipe.text_encoder.load_state_dict(torch.load("./models/change_pytorch_model.bin"), strict=False)
pipe = pipe.to("cuda")

import random
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

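# Map the selected task to prompt suffixes: each mode appends a learned task
# token to the positive/negative prompts, and the pipeline blends promptA and
# promptB according to the fitting degree (the `tradoff` arguments below).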
def add_task(prompt, negative_prompt, control_type):
    if control_type == 'Object_removal':
        promptA = prompt + " MMcontext"
        promptB = prompt + " MMcontext"
        negative_promptA = negative_prompt + " MMobject"
        negative_promptB = negative_prompt + " MMobject"
    elif control_type == 'Shape_object':
        promptA = prompt + " MMshape"
        promptB = prompt + " MMcontext"
        negative_promptA = negative_prompt + " MMshape"
        negative_promptB = negative_prompt + " MMcontext"
    elif control_type == 'Object_inpaint':
        promptA = prompt + " MMobject"
        promptB = prompt + " MMobject"
        negative_promptA = negative_prompt + " MMobject"
        negative_promptB = negative_prompt + " MMobject"

    return promptA, promptB, negative_promptA, negative_promptB

from PIL import Image, ImageFilter
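# Inference: resize the input so both sides are multiples of 8, run the
# pipeline, then return four views -- the mask, the result with the masked
# region tinted, the raw result, and the result pasted back onto the input
# through a Gaussian-blurred mask.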
def predict(input_image, mask_img, prompt, Fitting_degree, ddim_steps, scale, seed, negative_prompt, task):
    promptA, promptB, negative_promptA, negative_promptB = add_task(prompt, negative_prompt, task)
    input_image["mask"] = mask_img["image"]
    size1, size2 = input_image["image"].convert("RGB").size
    if size1 < size2:
        input_image["image"] = input_image["image"].convert("RGB").resize((640, int(size2 / size1 * 640)))
    else:
        input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 640), 640))
    img = np.array(input_image["image"].convert("RGB"))

    W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
    H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
    input_image["image"] = input_image["image"].resize((H, W))
    input_image["mask"] = input_image["mask"].resize((H, W))
    set_seed(seed)
    result = pipe(
        promptA=promptA, promptB=promptB,
        tradoff=Fitting_degree, tradoff_nag=Fitting_degree,
        negative_promptA=negative_promptA, negative_promptB=negative_promptB,
        image=input_image["image"].convert("RGB"),
        mask_image=input_image["mask"].convert("RGB"),
        width=H, height=W,
        guidance_scale=scale,
        num_inference_steps=ddim_steps).images[0]
    mask_np = np.array(input_image["mask"].convert("RGB"))
    red = np.array(result).astype('float') * 1
    red[:, :, 0] = 0
    red[:, :, 2] = 180.0
    red[:, :, 1] = 0
    result_m = np.array(result)
    result_m = Image.fromarray(
        (result_m.astype('float') * (1 - mask_np.astype('float') / 512.0)
         + mask_np.astype('float') / 512.0 * red).astype('uint8'))
    m_img = input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=4))
    m_img = np.asarray(m_img) / 255.0
    img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
    ours_np = np.asarray(result) / 255.0
    ours_np = ours_np * m_img + (1 - m_img) * img_np
    result_paste = Image.fromarray(np.uint8(ours_np * 255))

    dict_res = [input_image["mask"].convert("RGB"), result_m, result, result_paste]

    return dict_res

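# Gradio UI: image/mask inputs, prompts and sampling controls on the left,
# a gallery with the four output views on the right.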
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## PowerPaint")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', tool='sketch', type="pil")
            mask_img = gr.Image(source='upload', tool='sketch', type="pil")
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="negative_prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                control_type = gr.Radio(['Object_inpaint', 'Shape_object', 'Object_removal'])
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                Fitting_degree = gr.Slider(
                    label="Fitting degree",
                    minimum=0,
                    maximum=1,
                    step=0.05,
                    randomize=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
        with gr.Column():
            gallery = gr.Gallery(label="Generated images", show_label=False).style(
                grid=[2], height="auto")

    run_button.click(
        fn=predict,
        inputs=[input_image, mask_img, prompt, Fitting_degree, ddim_steps, scale, seed, negative_prompt, control_type],
        outputs=[gallery])
block.launch(share=True, server_name="0.0.0.0", server_port=9586)
204 changes: 204 additions & 0 deletions projects/powerpaint/gradio_PowerPaint_ControlNet.py
@@ -0,0 +1,204 @@
import sys
import cv2
import torch
import numpy as np
import gradio as gr
from diffusers.utils import load_image
from PIL import Image, ImageFilter

torch.set_grad_enabled(False)

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from pipeline.pipeline_PowerPaint_ControlNet import StableDiffusionControlNetInpaintPipeline as Pipeline
from diffusers.pipelines.controlnet.pipeline_controlnet import ControlNetModel
from utils.utils import *

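# ControlNet variant: same PowerPaint weights as gradio_PowerPaint.py, but the
# pipeline additionally takes a control image; a canny ControlNet is loaded as
# the default condition.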
weight_dtype = torch.float16
controlnet_conditioning_scale = 0.5
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype
)
controlnet = controlnet.to("cuda")
global pipe
pipe = Pipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None)


pipe.tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer", revision=None, torch_dtype=torch.float16)
pipe.text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder", revision=None, torch_dtype=torch.float16)
pipe.vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", revision=None, torch_dtype=torch.float16)

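# As in gradio_PowerPaint.py: register the task tokens and load the fine-tuned
# PowerPaint UNet / text encoder weights from ./models.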
pipe.tokenizer = TokenizerWrapper(
    from_pretrained="runwayml/stable-diffusion-v1-5", subfolder="tokenizer", revision=None)
add_tokens(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder,
           placeholder_tokens=["MMcontext", "MMshape", "MMobject"],
           initialize_tokens=["a", "a", "a"], num_vectors_per_token=10)
pipe.unet.load_state_dict(torch.load("./models/diffusion_pytorch_model.bin"), strict=False)
pipe.text_encoder.load_state_dict(torch.load("./models/change_pytorch_model.bin"), strict=False)
pipe = pipe.to("cuda")


import random
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

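# Intel DPT depth estimator, used to build the condition image for the
# 'depth' control mode.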
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")

def get_depth_map(image):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

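# Auxiliary detectors for the 'pose' and 'hed' control modes; the active
# ControlNet is swapped lazily inside predict() when the mode changes.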
global current_control
current_control = 'canny'
from controlnet_aux import HEDdetector
from controlnet_aux import OpenposeDetector
openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
hed = HEDdetector.from_pretrained('lllyasviel/ControlNet')

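# Inference: build the control image for the selected mode (canny edges,
# openpose, depth, or HED), swap in the matching ControlNet if needed, run the
# pipeline, and return the result views plus the control condition images.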
def predict(input_image, input_control_image, control_type, prompt, ddim_steps, scale, seed, negative_prompt):
    promptA = prompt + " MMobject"
    promptB = prompt + " MMobject"
    negative_promptA = negative_prompt + " MMobject"
    negative_promptB = negative_prompt + " MMobject"
    img = np.array(input_image["image"].convert("RGB"))
    W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
    H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
    input_image["image"] = input_image["image"].resize((H, W))
    input_image["mask"] = input_image["mask"].resize((H, W))
    print(np.shape(np.array(input_image["mask"].convert("RGB"))))
    print(np.shape(np.array(input_image["image"].convert("RGB"))))

    global current_control
    global pipe
    if current_control != control_type:
        if control_type == 'canny' or control_type is None:
            pipe.controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype)
        elif control_type == 'pose':
            pipe.controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=weight_dtype)
        elif control_type == 'depth':
            pipe.controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=weight_dtype)
        else:
            pipe.controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-hed", torch_dtype=weight_dtype)
        pipe = pipe.to("cuda")
        current_control = control_type

    controlnet_image = input_control_image
    if current_control == 'canny':
        controlnet_image = controlnet_image.resize((H, W))
        controlnet_image = np.array(controlnet_image)
        controlnet_image = cv2.Canny(controlnet_image, 100, 200)
        controlnet_image = controlnet_image[:, :, None]
        controlnet_image = np.concatenate([controlnet_image, controlnet_image, controlnet_image], axis=2)
        controlnet_image = Image.fromarray(controlnet_image)
    elif current_control == 'pose':
        controlnet_image = openpose(controlnet_image)
    elif current_control == 'depth':
        controlnet_image = controlnet_image.resize((H, W))
        controlnet_image = get_depth_map(controlnet_image)
    else:
        controlnet_image = hed(controlnet_image)

    mask_np = np.array(input_image["mask"].convert("RGB"))
    controlnet_image = controlnet_image.resize((H, W))
    controlnet_np = np.array(controlnet_image)
    set_seed(seed)
    result = pipe(promptA=promptB,
                  promptB=promptA,
                  tradoff=1.0,
                  tradoff_nag=1.0,
                  negative_promptA=negative_promptA,
                  negative_promptB=negative_promptB,
                  image=input_image["image"].convert("RGB"),
                  mask_image=input_image["mask"].convert("RGB"),
                  control_image=controlnet_image,
                  width=H,
                  height=W,
                  guidance_scale=scale,
                  num_inference_steps=ddim_steps).images[0]
    red = np.array(result).astype('float') * 1
    red[:, :, 0] = 180.0
    red[:, :, 2] = 0
    red[:, :, 1] = 0
    result_m = np.array(result)
    result_m = Image.fromarray(
        (result_m.astype('float') * (1 - mask_np.astype('float') / 512.0)
         + mask_np.astype('float') / 512.0 * red).astype('uint8'))

    controlnet_mask = Image.fromarray(
        (np.array(input_image["image"].convert("RGB")).astype('float') * (1 - mask_np.astype('float') / 255.0)
         + mask_np.astype('float') / 255.0 * controlnet_np).astype('uint8'))

    mask_np = np.array(input_image["mask"].convert("RGB"))
    m_img = input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=4))
    m_img = np.asarray(m_img) / 255.0
    img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
    ours_np = np.asarray(result) / 255.0
    ours_np = ours_np * m_img + (1 - m_img) * img_np
    result_paste = Image.fromarray(np.uint8(ours_np * 255))
    return [input_image["mask"].convert("RGB"), result_m, result, result_paste], [controlnet_image, controlnet_mask]


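# Gradio UI: input image + control image and controls on the left, result
# gallery and control-condition gallery on the right.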
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## PowerPaint with ControlNet")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Input image")
            input_image = gr.Image(source='upload', tool='sketch', type="pil")
            gr.Markdown("## Input control image")
            input_control_image = gr.Image(source='upload', type="pil")
            gr.Markdown("### Control type")
            control_type = gr.Radio(['canny', 'pose', 'depth', 'hed'])
            promptA = gr.Textbox(label="Prompt")
            negative_promptA = gr.Textbox(label="negative_prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=10, step=0.1)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
        with gr.Column():
            gallery = gr.Gallery(label="Generated images", show_label=False).style(
                grid=[2], height="auto")
            control_image_show = gr.Gallery(label="Control condition", show_label=False).style(
                grid=[2], height="auto")

    run_button.click(
        fn=predict,
        inputs=[input_image, input_control_image, control_type, promptA, ddim_steps, scale, seed, negative_promptA],
        outputs=[gallery, control_image_show])


block.launch(share=True, server_name="0.0.0.0", server_port=9586)