Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifying the task id and getting a task's position in the pending-task queue #14314

Merged
merged 4 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import piexif
import piexif.helper
from contextlib import closing

from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task

def script_name_to_index(name, scripts):
try:
Expand Down Expand Up @@ -337,6 +337,9 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel
return script_args

def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
task_id = create_task_id("text2img")
if txt2imgreq.force_task_id is not None:
task_id = txt2imgreq.force_task_id
script_runner = scripts.scripts_txt2img
if not script_runner.scripts:
script_runner.initialize_scripts(False)
Expand All @@ -363,6 +366,8 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
send_images = args.pop('send_images', True)
args.pop('save_images', None)

add_task_to_queue(task_id)

with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.is_api = True
Expand All @@ -372,12 +377,14 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):

try:
shared.state.begin(job="scripts_txt2img")
start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
Expand All @@ -387,6 +394,10 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
task_id = create_task_id("img2img")
if img2imgreq.force_task_id is not None:
task_id = img2imgreq.force_task_id

init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
Expand Down Expand Up @@ -423,6 +434,8 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
send_images = args.pop('send_images', True)
args.pop('save_images', None)

add_task_to_queue(task_id)

with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
Expand All @@ -433,12 +446,14 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):

try:
shared.state.begin(job="scripts_img2img")
start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
Expand Down Expand Up @@ -514,7 +529,7 @@ def progressapi(self, req: models.ProgressRequest = Depends()):
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)

return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)

def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image
Expand Down
2 changes: 2 additions & 0 deletions modules/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def generate_model(self):
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
]
).generate_model()

Expand All @@ -126,6 +127,7 @@ def generate_model(self):
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
]
).generate_model()

Expand Down
2 changes: 2 additions & 0 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_sampler_name: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''
force_task_id: str = None

cached_hr_uc = [None, None]
cached_hr_c = [None, None]
Expand Down Expand Up @@ -1358,6 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
inpainting_mask_invert: int = 0
initial_noise_multiplier: float = None
latent_mask: Image = None
force_task_id: str = None

image_mask: Any = field(default=None, init=False)

Expand Down
21 changes: 19 additions & 2 deletions modules/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
from modules.shared import opts

import modules.shared as shared

from collections import OrderedDict
import string
import random
from typing import List

current_task = None
pending_tasks = {}
pending_tasks = OrderedDict()
finished_tasks = []
recorded_results = []
recorded_results_limit = 2
Expand All @@ -34,6 +37,11 @@ def finish_task(id_task):
if len(finished_tasks) > 16:
finished_tasks.pop(0)

def create_task_id(task_type):
    """Generate a queue id of the form ``task(<task_type>-XXXXXXX)``.

    The 7-character suffix is drawn uniformly from uppercase ASCII letters
    and digits. Collisions are improbable but not impossible at this length.
    """
    alphabet = string.ascii_uppercase + string.digits
    suffix = "".join(random.choice(alphabet) for _ in range(7))
    return f"task({task_type}-{suffix})"

def record_results(id_task, res):
recorded_results.append((id_task, res))
Expand All @@ -44,6 +52,9 @@ def record_results(id_task, res):
def add_task_to_queue(id_job):
    """Register *id_job* as pending, recording the time it was enqueued.

    ``pending_tasks`` is an ``OrderedDict``, so insertion order doubles as
    queue order for ``get_pending_tasks``.
    """
    enqueued_at = time.time()
    pending_tasks[id_job] = enqueued_at

class PendingTasksResponse(BaseModel):
    """Response body for ``GET /internal/pendingTasks``."""

    # Number of tasks currently waiting in the queue.
    size: int = Field(title="Pending task size")
    # Ids of the queued tasks; presumably in enqueue (oldest-first) order,
    # since pending_tasks is an OrderedDict — confirm against add_task_to_queue.
    tasks: List[str] = Field(title="Pending task ids")

class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
Expand All @@ -63,8 +74,14 @@ class ProgressResponse(BaseModel):


def setup_progress_api(app):
    """Attach the internal progress-tracking endpoints to *app*.

    Registers ``GET /internal/pendingTasks`` and ``POST /internal/progress``;
    returns whatever the latter registration returns, matching the original
    contract.
    """
    pending_path = "/internal/pendingTasks"
    progress_path = "/internal/progress"
    app.add_api_route(pending_path, get_pending_tasks, methods=["GET"])
    return app.add_api_route(progress_path, progressapi, methods=["POST"], response_model=ProgressResponse)

def get_pending_tasks():
    """Report the ids of all tasks still waiting in the queue.

    Iterating ``pending_tasks`` (an ``OrderedDict``) yields ids in insertion
    order, i.e. oldest enqueued first.
    """
    queued_ids = [task for task in pending_tasks]
    return PendingTasksResponse(size=len(queued_ids), tasks=queued_ids)


def progressapi(req: ProgressRequest):
active = req.id_task == current_task
Expand Down