Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement get history for image captions, change openai to gpt3.5 #39

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions app/container/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from botocore.client import BaseClient
from dependency_injector import containers, providers
from dependency_injector.providers import Resource
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from replicate import Client

from app.infrastructure.auth0.auth0 import Auth0Service
Expand Down Expand Up @@ -122,10 +122,10 @@ class Container(containers.DeclarativeContainer):
model_id=config.infrastructures.replicate.caption_model.model_id,
caption_client=caption_client,
)
open_ai: OpenAI = providers.Singleton(
OpenAI,
open_ai: ChatOpenAI = providers.Singleton(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems clear from the code that you are not using any chat-specific functionality of OpenAI at all, so the base OpenAI class would be sufficient here — just update the model name.

ChatOpenAI,
model_name=config.infrastructures.open_ai.model_name,
openai_api_key=config.infrastructures.open_ai.openai_api_key,
max_tokens=config.infrastructures.open_ai.max_tokens,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this? This parameter should not be removed, regardless of which API is used.

temperature=config.infrastructures.open_ai.temperature,
)

Expand Down
11 changes: 8 additions & 3 deletions app/infrastructure/llm/caption.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
from typing import Generator

from langchain import OpenAI, PromptTemplate
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import BaseMessageChunk


class ChatGPTCaptionGenerator:
Expand All @@ -18,7 +20,7 @@ class ChatGPTCaptionGenerator:
}
"""

def __init__(self, model: OpenAI):
def __init__(self, model: ChatOpenAI):
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
self.model = model

Expand All @@ -35,4 +37,7 @@ def rewrite_caption(
description=caption,
)
for response in self.model.stream(prompt):
yield response
if isinstance(response, BaseMessageChunk):
yield response.dict()["content"]
elif isinstance(response, str):
yield response
11 changes: 8 additions & 3 deletions app/infrastructure/llm/vocabulary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from datetime import datetime
from typing import Generator

from langchain import OpenAI, PromptTemplate
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import BaseMessageChunk


class ChatGPTVocabularyGenerator:
Expand Down Expand Up @@ -32,7 +34,7 @@ class ChatGPTVocabularyGenerator:
'{"{{ learning_language }}": {{ format_output }}, "{{ primary_language }}": {{ format_output }}}'
)

def __init__(self, model: OpenAI):
def __init__(self, model: ChatOpenAI):
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
self.model = model

Expand Down Expand Up @@ -63,8 +65,11 @@ def generate_vocabulary_questions(
prompt = textwrap.dedent(prompt)
self.logger.debug(f"Request: {prompt}")
for text in self.model.stream(prompt):
if isinstance(text, BaseMessageChunk):
yield text.dict()["content"]
elif isinstance(text, str):
yield text
Comment on lines +68 to +71
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is going to work with the vocab generation,

Did you test these code changes before committing? Be honest

response += text
yield text

self.logger.debug(f"Execution time: {datetime.now() - start}")
self.logger.debug(f"Response: {response}")
29 changes: 25 additions & 4 deletions app/repositories/caption_repository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import AbstractContextManager
from typing import Callable
from typing import Callable, List

from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload

from app.models.db.image_caption import (
ImageCaptionLearningLanguage,
Expand Down Expand Up @@ -35,6 +35,24 @@ def add_image_caption(
session.commit()
return img_caption

def get_caption(self, user_id: str) -> List[ImageCaptionPrimaryLanguage]:
    """Return the stored image captions for *user_id*, newest first.

    Eagerly loads the primary-language relation and, through each
    learning caption, its learning-language relation, so callers can
    read the related rows without triggering extra lazy-load queries.
    """
    with self.session_factory() as session:
        # Eager-load options for the relations the caller will read.
        eager_opts = (
            joinedload(ImageCaptionPrimaryLanguage.primary_language),
            joinedload(
                ImageCaptionPrimaryLanguage.learning_caption
            ).joinedload(ImageCaptionLearningLanguage.learning_language),
        )
        history_query = (
            session.query(ImageCaptionPrimaryLanguage)
            .join(ImageCaptionLearningLanguage)
            .options(*eager_opts)
            .filter(ImageCaptionPrimaryLanguage.user_id == user_id)
            .order_by(ImageCaptionPrimaryLanguage.time_created.desc())
        )
        return history_query.all()


class TranslatedCaptionRepository:
def __init__(
Expand All @@ -44,10 +62,13 @@ def __init__(
self.session_factory = session_factory

def add_translated_caption(
self, learning_caption: str, image_caption_object: ImageCaptionPrimaryLanguage
self,
learning_caption: str,
learning_language: str,
image_caption_object: ImageCaptionPrimaryLanguage,
) -> ImageCaptionLearningLanguage:
with self.session_factory() as session:
learning_language_id = check_language(session, learning_caption)
learning_language_id = check_language(session, learning_language)
translated_image_caption = ImageCaptionLearningLanguage(
learning_language_id=learning_language_id,
learning_language_caption=learning_caption,
Expand Down
9 changes: 9 additions & 0 deletions app/routes/api_v1/endpoints/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,12 @@ async def generate_caption(
),
media_type="text/plain",
)


@router.get("/caption")
@inject
async def get_caption_history(
    auth: Auth0User = Depends(check_user),
    caption_service: CaptionService = Depends(Provide[Container.caption_service]),
) -> object:
    """Return the authenticated user's image-caption history.

    Resolves the current user via the ``check_user`` dependency and
    delegates to ``CaptionService.list_caption_history`` with the user's
    id. NOTE(review): the return annotation is ``object``; consider
    declaring a response model so FastAPI can document/validate output.
    """
    return caption_service.list_caption_history(user_id=auth.id)
7 changes: 6 additions & 1 deletion app/services/caption_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
from typing import Generator
from typing import Generator, List

from app.infrastructure.llm.caption import ChatGPTCaptionGenerator
from app.infrastructure.replicate.caption import CaptionGenerator
from app.models.db.image_caption import ImageCaptionPrimaryLanguage
from app.models.schemas.image_caption import ImageCaptionCreate
from app.repositories.caption_repository import (
CaptionRepository,
Expand Down Expand Up @@ -46,5 +47,9 @@ def get_caption_from_image(self, user_id: str, caption_input: dict) -> Generator
)
self.learning_caption_repository.add_translated_caption(
learning_caption=caption_data[caption_input["learning_language"]],
learning_language=caption_input["learning_language"],
image_caption_object=caption_insert_object,
)

def list_caption_history(self, user_id: str) -> List[ImageCaptionPrimaryLanguage]:
    """Return the user's stored captions from the primary-caption repository."""
    return self.primary_caption_repository.get_caption(user_id=user_id)
6 changes: 3 additions & 3 deletions app/tests/image_caption/test_image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
from fastapi.testclient import TestClient
from langchain import OpenAI
from langchain.chat_models import ChatOpenAI
from replicate import Client

from app.infrastructure.auth0.auth0 import Auth0Service
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_generate_caption(client):

primary_caption_repository_mock = mock.AsyncMock(spec=CaptionRepository)
learning_caption_repository_mock = mock.AsyncMock(spec=TranslatedCaptionRepository)
open_ai_mock = mock.Mock(spec=OpenAI)
open_ai_mock = mock.Mock(spec=ChatOpenAI)
open_ai_mock.stream.return_value = mock_chat_gpt_response()

replicate_caption_mock = mock.Mock(spec=Client)
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_generate_caption_fail(client):

primary_caption_repository_mock = mock.AsyncMock(spec=CaptionRepository)
learning_caption_repository_mock = mock.AsyncMock(spec=TranslatedCaptionRepository)
open_ai_mock = mock.Mock(spec=OpenAI)
open_ai_mock = mock.Mock(spec=ChatOpenAI)
open_ai_mock.stream.return_value = mock_fail_response()

replicate_caption_mock = mock.Mock(spec=Client)
Expand Down
6 changes: 3 additions & 3 deletions app/tests/vocabulary/test_vocabulary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
from fastapi.testclient import TestClient
from langchain import OpenAI
from langchain.chat_models import ChatOpenAI

from app.main import app
from app.models.db.vocabulary import Category
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_func_generate_vocabulary_questions(client):
auth_service_mock = mock_user()
voca_repo_mock = mock.Mock(spec=VocabularyRepository)

open_ai_mock = mock.Mock(spec=OpenAI)
open_ai_mock = mock.Mock(spec=ChatOpenAI)
open_ai_mock.stream.return_value = mock_chat_gpt_response()

app.container.auth.override(auth_service_mock)
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_func_generate_vocabulary_questions_invalid_level_type(client):
def test_prompt_parse_failed(client):
auth_service_mock = mock_user()
voca_repo_mock = mock.Mock(spec=VocabularyRepository)
open_ai_mock = mock.Mock(spec=OpenAI)
open_ai_mock = mock.Mock(spec=ChatOpenAI)
open_ai_mock.stream.return_value = mock_wrong_chat_gpt_response()
app.container.auth.override(auth_service_mock)
app.container.open_ai.override(open_ai_mock)
Expand Down
2 changes: 1 addition & 1 deletion config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ core:

infrastructures:
open_ai:
model_name: ${OPENAI_MODEL_NAME:"gpt-3.5-turbo"}
openai_api_key: ${OPENAI_API_KEY}
max_tokens: ${OPENAI_MAX_TOKENS:"-1"}
temperature: ${OPENAI_TEMPERATURE:"0.5"}
aws:
access_key_id: ${AWS_ACCESS_KEY_ID}
Expand Down