Skip to content
This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Commit

Permalink
fix async bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonio committed May 24, 2023
1 parent 76a39eb commit 089e8db
Showing 1 changed file with 53 additions and 10 deletions.
63 changes: 53 additions & 10 deletions src/BingImageCreator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,31 @@ class ImageGenAsync:
auth_cookie: str
"""

def __init__(self, auth_cookie: str, quiet: bool = False) -> None:
def __init__(
self,
auth_cookie: str = None,
debug_file: Union[str, None] = None,
quiet: bool = False,
all_cookies: List[Dict] = None,
) -> None:
if auth_cookie is None and not all_cookies:
raise Exception("No auth cookie provided")
self.session = aiohttp.ClientSession(
headers=HEADERS,
cookies={"_U": auth_cookie},
trust_env=True,
)
if all_cookies:
for cookie in all_cookies:
self.session.cookie_jar.update_cookies(
{cookie["name"]: cookie["value"]}
)
if auth_cookie:
self.session.cookie_jar.update_cookies({"_U": auth_cookie})
self.quiet = quiet
self.debug_file = debug_file
if self.debug_file:
self.debug = partial(debug, self.debug_file)

async def __aenter__(self):
return self
Expand Down Expand Up @@ -298,23 +316,30 @@ async def get_images(self, prompt: str) -> list:
raise Exception("No images")
return normal_image_links

async def save_images(self, links: list, output_dir: str) -> None:
async def save_images(
self, links: list, output_dir: str, file_name: str = None
) -> None:
"""
Saves images to output directory
"""
if self.debug_file:
self.debug(download_message)
if not self.quiet:
print("\nDownloading images...")
print(download_message)
with contextlib.suppress(FileExistsError):
os.mkdir(output_dir)
try:
fn = f"{file_name}_" if file_name else ""
jpeg_index = 0
for link in links:
while os.path.exists(os.path.join(output_dir, f"{jpeg_index}.jpeg")):
while os.path.exists(
os.path.join(output_dir, f"{fn}{jpeg_index}.jpeg")
):
jpeg_index += 1
async with self.session.get(link, raise_for_status=True) as response:
# save response to file
with open(
os.path.join(output_dir, f"{jpeg_index}.jpeg"), "wb"
os.path.join(output_dir, f"{fn}{jpeg_index}.jpeg"), "wb"
) as output_file:
async for chunk in response.content.iter_chunked(8192):
output_file.write(chunk)
Expand All @@ -324,10 +349,19 @@ async def save_images(self, links: list, output_dir: str) -> None:
) from url_exception


async def async_image_gen(args) -> None:
async with ImageGenAsync(args.U, args.quiet) as image_generator:
images = await image_generator.get_images(args.prompt)
await image_generator.save_images(images, output_dir=args.output_dir)
async def async_image_gen(
prompt: str,
output_dir: str,
u_cookie=None,
debug_file=None,
quiet=False,
all_cookies=None,
):
async with ImageGenAsync(
u_cookie, debug_file=debug_file, quiet=quiet, all_cookies=all_cookies
) as image_generator:
images = await image_generator.get_images(prompt)
await image_generator.save_images(images, output_dir=output_dir)


def main():
Expand Down Expand Up @@ -396,7 +430,16 @@ def main():
output_dir=args.output_dir,
)
else:
asyncio.run(async_image_gen(args))
asyncio.run(
async_image_gen(
args.prompt,
args.output_dir,
args.U,
args.debug_file,
args.quiet,
all_cookies=cookie_json,
)
)


if __name__ == "__main__":
Expand Down

0 comments on commit 089e8db

Please sign in to comment.