diff --git a/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json b/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json index f9e002f7d..2e260a452 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json +++ b/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json @@ -15,7 +15,8 @@ "boolean":"boolean", "json":"json", "table":"table", - "record":"record" + "record":"record", + "line_item":"line-item" }, "output_processing":{ "DEFAULT":"Default" diff --git a/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_helper.py b/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_helper.py index 4e6a25daa..f6ec9bf9d 100644 --- a/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_helper.py +++ b/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_helper.py @@ -148,7 +148,7 @@ def update_or_create_prompt_output( output = outputs.get(prompt.prompt_key) # TODO: use enums here - if prompt.enforce_type in {"json", "table", "record"}: + if prompt.enforce_type in {"json", "table", "record", "line-item"}: output = json.dumps(output) profile_manager = default_profile eval_metrics = outputs.get(f"{prompt.prompt_key}__evaluation", []) diff --git a/backend/prompt_studio/prompt_studio_v2/migrations/0006_alter_toolstudioprompt_enforce_type.py b/backend/prompt_studio/prompt_studio_v2/migrations/0006_alter_toolstudioprompt_enforce_type.py new file mode 100644 index 000000000..8dc5a79d4 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/migrations/0006_alter_toolstudioprompt_enforce_type.py @@ -0,0 +1,39 @@ +# Generated by Django 4.2.1 on 2025-01-09 21:09 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("prompt_studio_v2", "0005_alter_toolstudioprompt_required"), + ] + + operations = [ + migrations.AlterField( + model_name="toolstudioprompt", + name="enforce_type", + field=models.TextField( + blank=True, + choices=[ + ("Text", "Response sent as Text"), + ("number", "Response sent as number"), + ("email", "Response sent as email"), + ("date", "Response sent as date"), + ("boolean", "Response sent as boolean"), + ("json", "Response sent as json"), + ("table", "Response sent as table"), + ( + "record", + "Response sent for records. Entries of records are list of logical and organized individual entities with distint values", + ), + ( + "line-item", + "Response sent as line-item which is large a JSON output. If extraction stopped due to token limitation, we try to continue extraction from where it stopped", + ), + ], + db_comment="Field to store the type in which the response to be returned.", + default="Text", + ), + ), + ] diff --git a/backend/prompt_studio/prompt_studio_v2/models.py b/backend/prompt_studio/prompt_studio_v2/models.py index ccef50ac5..68f14d39f 100644 --- a/backend/prompt_studio/prompt_studio_v2/models.py +++ b/backend/prompt_studio/prompt_studio_v2/models.py @@ -27,6 +27,12 @@ class EnforceType(models.TextChoices): "logical and organized individual " "entities with distint values" ) + LINE_ITEM = "line-item", ( + "Response sent as line-item " + "which is large a JSON output. " + "If extraction stopped due to token limitation, " + "we try to continue extraction from where it stopped" + ) class PromptType(models.TextChoices): PROMPT = "PROMPT", "Response sent as Text" diff --git a/prompt-service/src/unstract/prompt_service/constants.py b/prompt-service/src/unstract/prompt_service/constants.py index 9d567b2de..88d33b08f 100644 --- a/prompt-service/src/unstract/prompt_service/constants.py +++ b/prompt-service/src/unstract/prompt_service/constants.py @@ -76,6 +76,7 @@ class PromptServiceContants: REQUIRED = "required" EXECUTION_SOURCE = "execution_source" METRICS = "metrics" + LINE_ITEM = "line-item" class RunLevel(Enum): diff --git a/prompt-service/src/unstract/prompt_service/helper.py b/prompt-service/src/unstract/prompt_service/helper.py index e451a2030..3797fcb52 100644 --- a/prompt-service/src/unstract/prompt_service/helper.py +++ b/prompt-service/src/unstract/prompt_service/helper.py @@ -28,6 +28,11 @@ from unstract.sdk.file_storage.constants import StorageType from unstract.sdk.file_storage.env_helper import EnvHelper +PAID_FEATURE_MSG = ( + "It is a cloud / enterprise feature. If you have purchased a plan and still " + "face this issue, please contact support" +) + load_dotenv() # Global variable to store plugins @@ -394,3 +399,82 @@ def extract_table( except table_extractor["exception_cls"] as e: msg = f"Couldn't extract table. {e}" raise APIError(message=msg) + + +def extract_line_item( + tool_settings: dict[str, Any], + output: dict[str, Any], + plugins: dict[str, dict[str, Any]], + structured_output: dict[str, Any], + llm: LLM, + file_path: str, + metadata: Optional[dict[str, str]], + execution_source: str, +) -> dict[str, Any]: + line_item_extraction_plugin: dict[str, Any] = plugins.get( + "line-item-extraction", {} + ) + if not line_item_extraction_plugin: + raise APIError(PAID_FEATURE_MSG) + + extract_file_path = file_path + if execution_source == ExecutionSource.IDE.value: + # Adjust file path to read from the extract folder + base_name = os.path.splitext(os.path.basename(file_path))[0] + extract_file_path = os.path.join( + os.path.dirname(file_path), "extract", f"{base_name}.txt" + ) + + # Read file content into context + if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE): + fs_instance: FileStorage = FileStorage(FileStorageProvider.LOCAL) + if execution_source == ExecutionSource.IDE.value: + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + if execution_source == ExecutionSource.TOOL.value: + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.TEMPORARY, + env_name=FileStorageKeys.TEMPORARY_REMOTE_STORAGE, + ) + + if not fs_instance.exists(extract_file_path): + raise FileNotFoundError( + f"The file at path '{extract_file_path}' does not exist." + ) + context = fs_instance.read(path=extract_file_path, encoding="utf-8", mode="rb") + else: + if not os.path.exists(extract_file_path): + raise FileNotFoundError( + f"The file at path '{extract_file_path}' does not exist." + ) + + with open(extract_file_path, encoding="utf-8") as file: + context = file.read() + + prompt = construct_prompt( + preamble=tool_settings.get(PSKeys.PREAMBLE, ""), + prompt=output["promptx"], + postamble=tool_settings.get(PSKeys.POSTAMBLE, ""), + grammar_list=tool_settings.get(PSKeys.GRAMMAR, []), + context=context, + platform_postamble="", + ) + + try: + line_item_extraction = line_item_extraction_plugin["entrypoint_cls"]( + llm=llm, + tool_settings=tool_settings, + output=output, + prompt=prompt, + structured_output=structured_output, + logger=current_app.logger, + ) + answer = line_item_extraction.run() + structured_output[output[PSKeys.NAME]] = answer + metadata[PSKeys.CONTEXT][output[PSKeys.NAME]] = [context] + return structured_output + except line_item_extraction_plugin["exception_cls"] as e: + msg = f"Couldn't extract table. {e}" + raise APIError(message=msg) diff --git a/prompt-service/src/unstract/prompt_service/main.py b/prompt-service/src/unstract/prompt_service/main.py index 7432b0a40..a66d2778f 100644 --- a/prompt-service/src/unstract/prompt_service/main.py +++ b/prompt-service/src/unstract/prompt_service/main.py @@ -12,6 +12,7 @@ from unstract.prompt_service.exceptions import APIError, ErrorResponse, NoPayloadError from unstract.prompt_service.helper import ( construct_and_run_prompt, + extract_line_item, extract_table, extract_variable, get_cleaned_context, @@ -264,6 +265,44 @@ def prompt_processor() -> Any: "Error while extracting table for the prompt", ) raise api_error + elif output[PSKeys.TYPE] == PSKeys.LINE_ITEM: + try: + structured_output = extract_line_item( + tool_settings=tool_settings, + output=output, + plugins=plugins, + structured_output=structured_output, + llm=llm, + file_path=file_path, + metadata=metadata, + execution_source=execution_source, + ) + metadata = query_usage_metadata(token=platform_key, metadata=metadata) + # TODO: Handle metrics for line-item extraction + response = { + PSKeys.METADATA: metadata, + PSKeys.OUTPUT: structured_output, + PSKeys.METRICS: metrics, + } + continue + except APIError as e: + app.logger.error( + "Failed to extract line-item for the prompt %s: %s", + output[PSKeys.NAME], + str(e), + ) + publish_log( + log_events_id, + { + "tool_id": tool_id, + "prompt_key": prompt_name, + "doc_name": doc_name, + }, + LogLevel.ERROR, + RunLevel.RUN, + "Error while extracting line-item for the prompt", + ) + raise e try: if chunk_size == 0: