Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/fast-embeddings #405

Merged
merged 9 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ updates:
- "patch"
ignore:
- dependency-name: faststream
- dependency-name: torch
8 changes: 1 addition & 7 deletions infrastructure/aws/ecs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,6 @@ module "worker" {
prefix = "redbox"
ecr_repository_uri = "${var.ecr_repository_uri}/redbox-worker"
ecs_cluster_id = module.cluster.ecs_cluster_id
health_check = {
healthy_threshold = 3
unhealthy_threshold = 3
accepted_response = "200"
path = "/health"
timeout = 5
}
state_bucket = var.state_bucket
vpc_id = data.terraform_remote_state.vpc.outputs.vpc_id
private_subnets = data.terraform_remote_state.vpc.outputs.private_subnets
Expand All @@ -168,6 +161,7 @@ module "worker" {
aws_lb_arn = module.load_balancer.alb_arn
ip_whitelist = var.external_ips
environment_variables = local.environment_variables
http_healthcheck = false
}


Expand Down
470 changes: 217 additions & 253 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ boto3 = "^1.34.106"
pydantic-settings = "^2.2.1"
sentence-transformers = "^2.6.0"
unstructured = {version = "0.13.7", extras = ["all-docs"]}
torch = "2.3.0"
torch = "2.2.2"


[tool.poetry.group.api.dependencies]
Expand Down
5 changes: 4 additions & 1 deletion worker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ ADD worker/ /app/worker
ADD download_embedder.py /app/
ADD worker/download_ocr_models.py /app/

ADD worker/health.sh /app/
RUN chmod +x /app/health.sh

# Download the model
RUN python download_embedder.py --embedding_model ${EMBEDDING_MODEL}

Expand All @@ -44,4 +47,4 @@ RUN python -m nltk.downloader averaged_perceptron_tagger
# Download the OCR models
RUN python download_ocr_models.py

CMD ["uvicorn", "worker.src.app:app", "--host", "0.0.0.0", "--port", "5000"]
CMD ["faststream", "run", "embedder.src.worker:app", "--workers", "3"]
4 changes: 4 additions & 0 deletions worker/health.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash

echo "I am healthy"
exit 0
59 changes: 20 additions & 39 deletions worker/src/app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#!/usr/bin/env python

import logging
from contextlib import asynccontextmanager
from datetime import datetime

from fastapi import Depends, FastAPI
from faststream.redis.fastapi import RedisRouter
from faststream import Context, ContextRepo, FastStream
from faststream.redis import RedisBroker

from redbox.model_db import SentenceTransformerDB
from redbox.models import Chunk, EmbedQueueItem, File, Settings, StatusResponse
from redbox.models import Chunk, EmbedQueueItem, File, Settings
from redbox.parsing import chunk_file
from redbox.storage.elasticsearch import ElasticsearchStorageHandler

Expand All @@ -18,26 +19,27 @@
env = Settings()


router = RedisRouter(url=env.redis_url)
broker = RedisBroker(url=env.redis_url)

publisher = router.broker.publisher(env.embed_queue_name)
publisher = broker.publisher(env.embed_queue_name)


def get_storage_handler():
@asynccontextmanager
async def lifespan(context: ContextRepo):
es = env.elasticsearch_client()
return ElasticsearchStorageHandler(es_client=es, root_index="redbox-data")
storage_handler = ElasticsearchStorageHandler(es_client=es, root_index="redbox-data")
model = SentenceTransformerDB(env.embedding_model)

context.set_global("storage_handler", storage_handler)
context.set_global("model", model)

def get_model() -> SentenceTransformerDB:
model = SentenceTransformerDB(env.embedding_model)
return model
yield


@router.subscriber(channel=env.ingest_queue_name)
@broker.subscriber(channel=env.ingest_queue_name)
async def ingest(
file: File,
storage_handler: ElasticsearchStorageHandler = Depends(get_storage_handler),
# embedding_model: SentenceTransformerDB = Depends(get_model),
storage_handler: ElasticsearchStorageHandler = Context(),
):
"""
1. Chunks file
Expand All @@ -63,45 +65,24 @@ async def ingest(
return items


@router.subscriber(channel=env.embed_queue_name)
@broker.subscriber(channel=env.embed_queue_name)
async def embed(
queue_item: EmbedQueueItem,
storage_handler: ElasticsearchStorageHandler = Depends(get_storage_handler),
embedding_model: SentenceTransformerDB = Depends(get_model),
storage_handler: ElasticsearchStorageHandler = Context(),
model: SentenceTransformerDB = Context(),
):
"""
1. embed queue-item text
2. update related chunk on ES
"""

chunk: Chunk = storage_handler.read_item(queue_item.chunk_uuid, "Chunk")
embedded_sentences = embedding_model.embed_sentences([chunk.text])
embedded_sentences = model.embed_sentences([chunk.text])
if len(embedded_sentences.data) != 1:
logging.error("expected just 1 embedding but got %s", len(embedded_sentences.data))
return
chunk.embedding = embedded_sentences.data[0].embedding
storage_handler.update_item(chunk)


app = FastAPI(lifespan=router.lifespan_context)
app.include_router(router)


@app.get("/health", tags=["health"])
def health() -> StatusResponse:
"""Returns the health of the API

Returns:
StatusResponse: The health of the API
"""

uptime = datetime.now() - start_time
uptime_seconds = uptime.total_seconds()

output = StatusResponse(
status="ready",
uptime_seconds=uptime_seconds,
version=app.version,
)

return output
app = FastStream(broker=broker, lifespan=lifespan)
18 changes: 4 additions & 14 deletions worker/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest
from faststream.redis import TestRedisBroker
from faststream.redis import TestApp, TestRedisBroker

from redbox.storage import ElasticsearchStorageHandler
from worker.src.app import env, router
from worker.src.app import app, broker, env


@pytest.mark.asyncio
Expand All @@ -19,7 +19,7 @@ async def test_ingest_file(s3_client, es_client, embedding_model, file):

storage_handler.write_item(file)

async with TestRedisBroker(router.broker) as br:
async with TestRedisBroker(broker) as br, TestApp(app):
await br.publish(file, channel=env.ingest_queue_name)

file = storage_handler.read_item(
Expand All @@ -40,18 +40,8 @@ async def test_embed_item_callback(elasticsearch_storage_handler, embed_queue_it
unembedded_chunk = elasticsearch_storage_handler.read_item(embed_queue_item.chunk_uuid, "Chunk")
assert unembedded_chunk.embedding is None

async with TestRedisBroker(router.broker) as br:
async with TestRedisBroker(broker) as br:
await br.publish(embed_queue_item, channel=env.embed_queue_name)

embedded_chunk = elasticsearch_storage_handler.read_item(embed_queue_item.chunk_uuid, "Chunk")
assert embedded_chunk.embedding is not None


def test_get_health(app_client):
"""
Given that the app is running
When I call /health
I Expect to see the docs
"""
response = app_client.get("/health")
assert response.status_code == 200
Loading