-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement get history for image captions, change openai to gpt3.5 #39
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
from botocore.client import BaseClient | ||
from dependency_injector import containers, providers | ||
from dependency_injector.providers import Resource | ||
from langchain.llms import OpenAI | ||
from langchain.chat_models import ChatOpenAI | ||
from replicate import Client | ||
|
||
from app.infrastructure.auth0.auth0 import Auth0Service | ||
|
@@ -122,10 +122,10 @@ class Container(containers.DeclarativeContainer): | |
model_id=config.infrastructures.replicate.caption_model.model_id, | ||
caption_client=caption_client, | ||
) | ||
open_ai: OpenAI = providers.Singleton( | ||
OpenAI, | ||
open_ai: ChatOpenAI = providers.Singleton( | ||
ChatOpenAI, | ||
model_name=config.infrastructures.open_ai.model_name, | ||
openai_api_key=config.infrastructures.open_ai.openai_api_key, | ||
max_tokens=config.infrastructures.open_ai.max_tokens, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why remove this, this param should not be remove even using whatever API |
||
temperature=config.infrastructures.open_ai.temperature, | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,9 @@ | |
from datetime import datetime | ||
from typing import Generator | ||
|
||
from langchain import OpenAI, PromptTemplate | ||
from langchain import PromptTemplate | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.schema.messages import BaseMessageChunk | ||
|
||
|
||
class ChatGPTVocabularyGenerator: | ||
|
@@ -32,7 +34,7 @@ class ChatGPTVocabularyGenerator: | |
'{"{{ learning_language }}": {{ format_output }}, "{{ primary_language }}": {{ format_output }}}' | ||
) | ||
|
||
def __init__(self, model: OpenAI): | ||
def __init__(self, model: ChatOpenAI): | ||
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}") | ||
self.model = model | ||
|
||
|
@@ -63,8 +65,11 @@ def generate_vocabulary_questions( | |
prompt = textwrap.dedent(prompt) | ||
self.logger.debug(f"Request: {prompt}") | ||
for text in self.model.stream(prompt): | ||
if isinstance(text, BaseMessageChunk): | ||
yield text.dict()["content"] | ||
elif isinstance(text, str): | ||
yield text | ||
Comment on lines
+68
to
+71
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this is going to work with the vocab generation. Did you test these code changes before committing? Be honest. |
||
response += text | ||
yield text | ||
|
||
self.logger.debug(f"Execution time: {datetime.now() - start}") | ||
self.logger.debug(f"Response: {response}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems clear from the code that you're not using anything specific to the "Chat" functionality of OpenAI at all, so the base OpenAI class would be sufficient here — just update the model name.