Skip to content

Commit

Permalink
bugfix/compute-full-token-count-for-user-message
Browse files Browse the repository at this point in the history
  • Loading branch information
gecBurton committed Feb 7, 2025
1 parent 5ec4d89 commit 666c63c
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 12 deletions.
14 changes: 13 additions & 1 deletion django_app/redbox_app/redbox_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,10 +654,22 @@ class Role(models.TextChoices):
def __str__(self) -> str: # pragma: no cover
return textwrap.shorten(self.text, width=20, placeholder="...")

@property
def associated_file_token_count(self):
"""count token of all files created before this chat
that would have been used in the creation of this message
"""
if self.role == self.Role.ai:
return 0

return self.chat.file_set.filter(
created_at__lt=datetime.now(tz=utc),
).aggregate(Sum("token_count"))["token_count__sum"] or 0

def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
self.text = sanitise_string(self.text)
self.rating_text = sanitise_string(self.rating_text)
self.token_count = len(tokeniser.encode(self.text))
self.token_count = self.associated_file_token_count + len(tokeniser.encode(self.text))
super().save(force_insert, force_update, using, update_fields)
self.log()

Expand Down
27 changes: 27 additions & 0 deletions django_app/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import pytest
from django.core.files.uploadedfile import SimpleUploadedFile
from freezegun import freeze_time
from lxml.html.diff import token
from pytz import utc

from redbox_app.redbox_core.models import (
ChatMessage,
Expand Down Expand Up @@ -35,3 +38,27 @@ def test_chat_message_model_token_count_on_save(chat):
assert not chat_message.token_count
chat_message.save()
assert chat_message.token_count == 4


@pytest.mark.django_db()
@pytest.mark.parametrize("role, expected_count", [(ChatMessage.Role.ai, 0), (ChatMessage.Role.user, 100)])
def test_associated_file_token_count(chat, original_file, role, expected_count):

now = datetime.now(tz=utc)

# Given a chat message...
with freeze_time(now):
chat_message = ChatMessage.objects.create(chat=chat, role=role, text="I am a message")

# and a file created before it...
with freeze_time(now - timedelta(seconds=1)):
File.objects.create(original_file=original_file, chat=chat, token_count=100)


# and a file created after it...
with freeze_time(now + timedelta(seconds=1)):
File.objects.create(original_file=original_file, chat=chat, token_count=200)

# when i call associated_file_token_count
# I expect to see the token count for the file created before it in the count
assert chat_message.associated_file_token_count == expected_count
14 changes: 7 additions & 7 deletions infrastructure/aws/ecs.tf
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module "cluster" {
# checkov:skip=CKV_TF_1: We're using semantic versions instead of commit hash
# source = "../../../i-ai-core-infrastructure//modules/ecs_cluster"
source = "git::https://github.com/i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/ecs-cluster?ref=v1.0.0-ecs-cluster"
source = "git@github.com:i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/ecs-cluster?ref=v1.0.0-ecs-cluster"
name = local.name
}

Expand Down Expand Up @@ -60,13 +60,13 @@ resource "aws_secretsmanager_secret_version" "django-app-json-secret" {
module "django-app" {
# checkov:skip=CKV_TF_1: We're using semantic versions instead of commit hash
#source = "../../i-dot-ai-core-terraform-modules//modules/infrastructure/ecs" # For testing local changes
source = "git::https://github.com/i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/ecs?ref=v1.0.0-ecs"
source = "git@github.com:i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/ecs?ref=v1.0.0-ecs"
memory = 4096
cpu = 2048
create_listener = true
create_networking = true
name = "${local.name}-django-app"
image_tag = var.image_tag
image_tag = "5ec4d891cac87760c61ca6c45f93b8e9e9411e56"
ecr_repository_uri = "${var.ecr_repository_uri}/${var.project_name}-django-app"
ecs_cluster_id = module.cluster.ecs_cluster_id
ecs_cluster_name = module.cluster.ecs_cluster_name
Expand Down Expand Up @@ -99,14 +99,14 @@ module "django-app" {
module "worker" {
# checkov:skip=CKV_TF_1: We're using semantic versions instead of commit hash
#source = "../../i-dot-ai-core-terraform-modules//modules/infrastructure/ecs" # For testing local changes
source = "git::https://github.com/i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/ecs?ref=v1.0.0-ecs"
source = "git@github.com:i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/ecs?ref=v1.0.0-ecs"
command = ["venv/bin/django-admin", "qcluster"]
memory = 6144
cpu = 2048
create_listener = false
create_networking = false
name = "${local.name}-worker"
image_tag = var.image_tag
image_tag = "5ec4d891cac87760c61ca6c45f93b8e9e9411e56"
ecr_repository_uri = "${var.ecr_repository_uri}/${var.project_name}-django-app"
ecs_cluster_id = module.cluster.ecs_cluster_id
ecs_cluster_name = module.cluster.ecs_cluster_name
Expand All @@ -131,14 +131,14 @@ module "worker" {
module "lit-ssr" {
# checkov:skip=CKV_TF_1: We're using semantic versions instead of commit hash
#source = "../../i-dot-ai-core-terraform-modules//modules/infrastructure/ecs" # For testing local changes
source = "git::https://github.com/i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/ecs?ref=v1.0.0-ecs"
source = "git@github.com:i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/ecs?ref=v1.0.0-ecs"
service_discovery_service_arn = aws_service_discovery_service.lit_ssr_service_discovery_service.arn
memory = 6144
cpu = 2048
create_listener = false
create_networking = false
name = "${local.name}-lit-ssr"
image_tag = var.image_tag
image_tag = "5ec4d891cac87760c61ca6c45f93b8e9e9411e56"
ecr_repository_uri = "${var.ecr_repository_uri}/${var.project_name}-lit-ssr"
ecs_cluster_id = module.cluster.ecs_cluster_id
ecs_cluster_name = module.cluster.ecs_cluster_name
Expand Down
4 changes: 2 additions & 2 deletions infrastructure/aws/load_balancer.tf
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module "load_balancer" {
# checkov:skip=CKV_TF_1: We're using semantic versions instead of commit hash
#source = "../../i-dot-ai-core-terraform-modules//modules/infrastructure/load_balancer" # For testing local changes
source = "git::https://github.com/i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/load_balancer?ref=v1.0.0-load_balancer"
source = "git@github.com:i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/load_balancer?ref=v1.0.0-load_balancer"
name = local.name
account_id = var.account_id
vpc_id = data.terraform_remote_state.vpc.outputs.vpc_id
Expand All @@ -16,7 +16,7 @@ module "load_balancer" {
module "waf" {
# checkov:skip=CKV_TF_1: We're using semantic versions instead of commit hash
#source = "../../i-dot-ai-core-terraform-modules//modules/infrastructure/waf" # For testing local changes
source = "git::https://github.com/i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/waf?ref=v1.0.0-waf"
source = "git@github.com:i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/waf?ref=v1.0.0-waf"
name = local.name
ip_set = concat(var.internal_ips, var.developer_ips, var.external_ips)
scope = var.scope
Expand Down
2 changes: 1 addition & 1 deletion infrastructure/aws/postgres/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ locals {
module "postgres" {
# checkov:skip=CKV_TF_1: We're using semantic versions instead of commit hash
# source = "../../../i-ai-core-infrastructure//modules/postgres"
source = "git::https://github.com/i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/postgres?ref=v1.0.0-postgres"
source = "git@github.com:i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/postgres?ref=v1.0.0-postgres"
kms_secrets_arn = data.terraform_remote_state.platform.outputs.kms_key_arn
name = local.name
db_name = "postgres"
Expand Down
2 changes: 1 addition & 1 deletion infrastructure/aws/rds.tf
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module "rds" {
# checkov:skip=CKV_TF_1: We're using semantic versions instead of commit hash
# source = "../../../i-dot-ai-core-terraform-modules//modules/infrastructure/rds" # For testing local changes
source = "git::https://github.com/i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/rds?ref=v1.0.0-rds"
source = "git@github.com:i-dot-ai/i-dot-ai-core-terraform-modules.git//modules/infrastructure/rds?ref=v1.0.0-rds"
name = local.name
db_name = var.project_name
domain_name = var.domain_name
Expand Down

0 comments on commit 666c63c

Please sign in to comment.