Skip to content

Commit

Permalink
feat(bing): support draw image from text prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
lss233 committed Apr 26, 2023
1 parent 11e7a10 commit 7fdf3c1
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 26 deletions.
77 changes: 52 additions & 25 deletions adapter/ms/bing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import json
from io import BytesIO
from typing import Generator, Union, List

import aiohttp
import asyncio
from PIL import Image

from constants import config
from adapter.botservice import BotAdapter
from EdgeGPT import Chatbot as EdgeChatbot, ConversationStyle
from EdgeGPT import Chatbot as EdgeChatbot, ConversationStyle, NotAllowedToAccess
from contextlib import suppress

from constants import botManager
from drawing import DrawingAPI
Expand All @@ -15,6 +19,9 @@
from ImageGen import ImageGenAsync
from graia.ariadne.message.element import Image as GraiaImage

image_pattern = r"!\[.*\]\((.*)\)"


class BingAdapter(BotAdapter, DrawingAPI):
cookieData = None
count: int = 0
Expand Down Expand Up @@ -49,52 +56,69 @@ async def on_reset(self):
async def ask(self, prompt: str) -> Generator[str, None, None]:
self.count = self.count + 1
parsed_content = ''
image_urls = []
try:
async for final, response in self.bot.ask_stream(prompt=prompt,
conversation_style=self.conversation_style,
wss_link=config.bing.wss_link):
if not response:
continue
if not final:
response = re.sub(r"\[\^\d+\^\]", "", response)
if config.bing.show_references:
response = re.sub(r"\[(\d+)\]: ", r"\1: ", response)
else:
response = re.sub(r"(\[\d+\]\: .+)+", "", response)
parsed_content = response

else:
try:
if final:
# 最后一条消息
max_messages = config.bing.max_messages
with suppress(KeyError):
max_messages = response["item"]["throttling"]["maxNumUserMessagesInConversation"]
except Exception:
max_messages = config.bing.max_messages
if config.bing.show_remaining_count:
remaining_conversations = f'\n剩余回复数:{self.count} / {max_messages} '
else:
remaining_conversations = ''

with suppress(KeyError):
raw_text = response["item"]["messages"][1]["adaptiveCards"][0]["body"][0]["text"]
image_urls = re.findall(image_pattern, raw_text)

remaining_conversations = f'\n剩余回复数:{self.count} / {max_messages} ' \
if config.bing.show_remaining_count else ''

if len(response["item"].get('messages', [])) > 1 and config.bing.show_suggestions:
suggestions = response["item"]["messages"][-1].get("suggestedResponses", [])
if len(suggestions) > 0:
parsed_content = parsed_content + '\n猜你想问: \n'
for suggestion in suggestions:
parsed_content = f"{parsed_content}* {suggestion.get('text')} \n"
yield parsed_content

parsed_content = parsed_content + remaining_conversations
# not final的parsed_content已经yield走了,只能在末尾加剩余回复数,或者改用EdgeGPT自己封装的ask之后再正则替换

if parsed_content == remaining_conversations: # No content
yield "Bing 已结束本次会话。继续发送消息将重新开启一个新会话。"
await self.on_reset()
return
else:
# 生成中的消息
parsed_content = re.sub(r"\[\^\d+\^\]", "", response)
if config.bing.show_references:
parsed_content = re.sub(r"\[(\d+)\]: ", r"\1: ", parsed_content)
else:
parsed_content = re.sub(r"(\[\d+\]\: .+)+", "", parsed_content)
parts = re.split(image_pattern, parsed_content)
# 图片单独保存
parsed_content = parts[0]

if len(parts) > 2:
parsed_content = parsed_content + parts[-1]

yield parsed_content
logger.debug(f"[Bing AI 响应] {parsed_content}")
image_tasks = [
asyncio.create_task(self.__download_image(url))
for url in image_urls
]
for image in await asyncio.gather(*image_tasks):
yield image
except (asyncio.exceptions.TimeoutError, asyncio.exceptions.CancelledError) as e:
raise e
except Exception as e:
logger.exception(e)
yield "Bing 已结束本次会话。继续发送消息将重新开启一个新会话。"
await self.on_reset()
except NotAllowedToAccess:
yield "出现错误:机器人的 Bing Cookie 可能已过期,或者机器人当前使用的 IP 无法使用 Bing AI。"
return
except Exception as e:
raise e

async def text_to_img(self, prompt: str):
logger.debug(f"[Bing Image] Prompt: {prompt}")
Expand All @@ -112,11 +136,14 @@ async def text_to_img(self, prompt: str):
async def img_to_img(self, init_images: List[GraiaImage], prompt=''):
return await self.text_to_img(prompt)

async def __download_image(self, url):
async def __download_image(self, url) -> GraiaImage:
logger.debug(f"[Bing AI] 下载图片:{url}")

async with aiohttp.ClientSession() as session:
async with session.get(url, proxy=self.bot.proxy) as resp:
if resp.status == 200:
return GraiaImage(data_bytes=await resp.read())
resp.raise_for_status()
logger.debug(f"[Bing AI] 下载完成:{resp.content_type} {url}")
return GraiaImage(data_bytes=await resp.read())

async def preset_ask(self, role: str, text: str):
yield None # Bing 不使用预设功能
3 changes: 2 additions & 1 deletion platforms/onebot_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ def transform_message_chain(text: str) -> MessageChain:
def transform_from_message_chain(chain: MessageChain):
result = ''
for elem in chain:
if isinstance(elem, Image):
if isinstance(elem, (Image, GraiaImage)):
result = result + MessageSegment.image(f"base64://{elem.base64}")
elif isinstance(elem, Plain):
result = result + MessageSegment.text(str(elem))
elif isinstance(elem, Voice):
result = result + MessageSegment.record(f"base64://{elem.base64}")
logger.debug(result)
return result


Expand Down

0 comments on commit 7fdf3c1

Please sign in to comment.