Skip to content

Commit

Permalink
Allow SVGVariations to contain more than one conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
opcode81 committed Jun 27, 2024
1 parent 27597bf commit 8b89381
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions src/penai/variations/svg_variations.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,33 @@ def __init__(
self,
original_svg: SVG,
variations_dict: dict[str, str],
conversation: SVGVariationsConversation | None = None,
conversation: SVGVariationsConversation | list[SVGVariationsConversation] | None = None,
):
""":param original_svg: the original SVG
:param variations_dict: a mapping from variation name to SVG code
"""
self.variations_dict = variations_dict
self.original_svg = original_svg
self.conversation = conversation
if conversation is None:
conversation = []
elif isinstance(conversation, SVGVariationsConversation):
conversation = [conversation]
self._conversations = conversation

@property
def conversation(self) -> SVGVariationsConversation | None:
"""Returns the main conversation (if there is only one).
:return: the main conversation
"""
if len(self._conversations) == 1:
return self._conversations[0]
else:
return None

def conversations(self) -> list[SVGVariationsConversation]:
""":return: the list of all conversations (may be empty)"""
return self._conversations

def iter_variations_name_svg(self) -> Iterator[tuple[str, SVG]]:
for name, svg_text in self.variations_dict.items():
Expand All @@ -176,7 +195,7 @@ def revise(
preprompt: str = REVISION_PREPROMPT,
) -> "SVGVariations":
if self.conversation is None:
raise ValueError("Cannot revise without a conversation")
raise ValueError("Cannot revise without a (single main) conversation")
conversation = self.conversation.clone()
revision_prompt = preprompt + revision_logic
response = conversation.query(revision_prompt)
Expand All @@ -190,6 +209,13 @@ def write_results(self, result_writer: ResultWriter, file_prefix: str = "") -> N
self.conversation.get_full_conversation_string(),
content_description="full conversation",
)
elif len(self._conversations) > 1:
for i, conversation in enumerate(self._conversations, start=1):
result_writer.write_text_file(
f"{file_prefix}conversation_{i}.md",
conversation.get_full_conversation_string(),
content_description=f"conversation {i}",
)
result_writer.write_text_file(
f"{file_prefix}variations.html",
self.to_html(),
Expand Down Expand Up @@ -465,6 +491,7 @@ def create_variations_from_example(
)

variations_dict = {}
conversations = []
for _i, (name, svg_text) in enumerate(example_variations.variations_dict.items()):
conversation = self._create_conversation(system_prompt=system_prompt)
prompt = (
Expand All @@ -481,8 +508,9 @@ def create_variations_from_example(
if len(code_snippets) > 1:
log.warning("Received more than one code snippet in response; using the first one")
variations_dict[name] = code_snippets[0].code
conversations.append(conversation)

variations = SVGVariations(self.svg, variations_dict)
variations = SVGVariations(self.svg, variations_dict, conversations)
variations.write_results(self.result_writer)
self.result_writer.write_text_file("example_presented.html", example_variations.to_html())
return variations

0 comments on commit 8b89381

Please sign in to comment.