Skip to content

Commit

Permalink
feat: multi-document support for Prompt Studio (#729)
Browse files Browse the repository at this point in the history
* added support for image, docx

* added support for image, docx (backend)

* added doc to pdf converter

* Removed unwanted changes 

Signed-off-by: jagadeeswaran-zipstack <[email protected]>

* code refactor

* removed unwanted logs and code refactor

* removed unwanted code

* removed unwanted code

* removed plugin related logic

* code refactor

* added types

* code refactored

* merge conflict fix

---------

Signed-off-by: jagadeeswaran-zipstack <[email protected]>
Co-authored-by: vishnuszipstack <[email protected]>
  • Loading branch information
jagadeeswaran-zipstack and vishnuszipstack authored Jan 9, 2025
1 parent fc60cf0 commit b69ec16
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 97 deletions.
33 changes: 25 additions & 8 deletions backend/file_management/file_management_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,31 @@ def upload_file(
# adding filename with path
file_path += file_name
with fs.open(file_path, mode="wb") as remote_file:
remote_file.write(file.read())
if isinstance(file, bytes):
remote_file.write(file)
else:
remote_file.write(file.read())

@staticmethod
@deprecated(reason="Use remote FS APIs from SDK")
def fetch_file_contents(file_system: UnstractFileSystem, file_path: str) -> Any:
def fetch_file_contents(
file_system: UnstractFileSystem,
file_path: str,
allowed_content_types: list[str],
) -> Any:
fs = file_system.get_fsspec_fs()

try:
file_info = fs.info(file_path)
except FileNotFoundError:
raise FileNotFound

file_content_type = file_info.get("ContentType")
file_type = file_info.get("type")

if file_type != "file":
raise InvalidFileType

try:
if not file_content_type:
file_content_type, _ = mimetypes.guess_type(file_path)
Expand All @@ -165,19 +175,26 @@ def fetch_file_contents(file_system: UnstractFileSystem, file_path: str) -> Any:
except ApiRequestError as exception:
logger.error(f"ApiRequestError from {file_info} {exception}")
raise ConnectorApiRequestError

data = ""
# Check if the file type is in the allowed list
if file_content_type not in allowed_content_types:
raise InvalidFileType(f"File type '{file_content_type}' is not allowed.")

# Handle allowed file types
if file_content_type == "application/pdf":
# Read contents of PDF file into a string
with fs.open(file_path, "rb") as file:
encoded_string = base64.b64encode(file.read())
return encoded_string
data = base64.b64encode(file.read())

elif file_content_type == "text/plain":
with fs.open(file_path, "r") as file:
logger.info(f"Reading text file: {file_path}")
text_content = file.read()
return text_content
data = file.read()

else:
raise InvalidFileType
logger.warning(f"File type '{file_content_type}' is not handled.")

return {"data": data, "mime_type": file_content_type}

@staticmethod
def _delete_file(fs, file_path):
Expand Down
14 changes: 10 additions & 4 deletions backend/prompt_studio/prompt_studio_core_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from account_v2.models import User
from account_v2.serializer import UserSerializer
from django.core.exceptions import ObjectDoesNotExist
from file_management.constants import FileInformationKey
from prompt_studio.prompt_profile_manager_v2.models import ProfileManager
from prompt_studio.prompt_studio_core_v2.constants import ToolStudioKeys as TSKeys
from prompt_studio.prompt_studio_core_v2.exceptions import DefaultProfileError
Expand All @@ -23,6 +22,13 @@

logger = logging.getLogger(__name__)

try:
from plugins.processor.file_converter.constants import (
ExtentedFileInformationKey as FileKey,
)
except ImportError:
from file_management.constants import FileInformationKey as FileKey


class CustomToolSerializer(IntegrityErrorMixin, AuditSerializer):
shared_users = serializers.PrimaryKeyRelatedField(
Expand Down Expand Up @@ -151,10 +157,10 @@ class FileUploadIdeSerializer(serializers.Serializer):
required=True,
validators=[
FileValidator(
allowed_extensions=FileInformationKey.FILE_UPLOAD_ALLOWED_EXT,
allowed_mimetypes=FileInformationKey.FILE_UPLOAD_ALLOWED_MIME,
allowed_extensions=FileKey.FILE_UPLOAD_ALLOWED_EXT,
allowed_mimetypes=FileKey.FILE_UPLOAD_ALLOWED_MIME,
min_size=0,
max_size=FileInformationKey.FILE_UPLOAD_MAX_SIZE,
max_size=FileKey.FILE_UPLOAD_MAX_SIZE,
)
],
)
29 changes: 26 additions & 3 deletions backend/prompt_studio/prompt_studio_core_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.db import IntegrityError
from django.db.models import QuerySet
from django.http import HttpRequest
from file_management.constants import FileInformationKey as FileKey
from file_management.exceptions import FileNotFound
from file_management.file_management_helper import FileManagerHelper
from permissions.permission import IsOwner, IsOwnerOrSharedUser
Expand Down Expand Up @@ -404,7 +405,14 @@ def fetch_contents_ide(self, request: HttpRequest, pk: Any = None) -> Response:
document: DocumentManager = DocumentManager.objects.get(pk=document_id)
file_name: str = document.document_name
view_type: str = serializer.validated_data.get("view_type")
file_converter = get_plugin_class_by_name(
name="file_converter",
plugins=self.processor_plugins,
)

allowed_content_types = FileKey.FILE_UPLOAD_ALLOWED_MIME
if file_converter:
allowed_content_types = file_converter.get_extented_file_information_key()
filename_without_extension = file_name.rsplit(".", 1)[0]
if view_type == FileViewTypes.EXTRACT:
file_name = (
Expand All @@ -430,7 +438,9 @@ def fetch_contents_ide(self, request: HttpRequest, pk: Any = None) -> Response:
file_path += file_name
# Temporary Hack for frictionless onboarding as the user id will be empty
try:
contents = FileManagerHelper.fetch_file_contents(file_system, file_path)
contents = FileManagerHelper.fetch_file_contents(
file_system, file_path, allowed_content_types
)
except FileNotFound:
file_path = file_path = (
FileManagerHelper.handle_sub_directory_for_tenants(
Expand Down Expand Up @@ -462,10 +472,23 @@ def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
serializer = FileUploadIdeSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
uploaded_files: Any = serializer.validated_data.get("file")
file_converter = get_plugin_class_by_name(
name="file_converter",
plugins=self.processor_plugins,
)

documents = []
for uploaded_file in uploaded_files:
# Store file
file_name = uploaded_file.name
file_data = uploaded_file
file_type = uploaded_file.content_type
# Convert non-PDF files
if file_converter and file_type != "application/pdf":
file_data, file_name = file_converter.process_file(
uploaded_file, file_name
)

logger.info(
f"Uploading file: {file_name}" if file_name else "Uploading file"
)
Expand All @@ -480,15 +503,15 @@ def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
FileManagerHelper.upload_file(
file_system,
file_path,
uploaded_file,
file_data,
file_name,
)
else:
PromptStudioFileHelper.upload_for_ide(
org_id=UserSessionUtils.get_organization_id(request),
user_id=custom_tool.created_by.user_id,
tool_id=str(custom_tool.tool_id),
uploaded_file=uploaded_file,
uploaded_file=file_data,
)

# Create a record in the db for the file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ import { useEffect, useState } from "react";
import { useParams } from "react-router-dom";
import "./DocumentManager.css";

import { base64toBlob, docIndexStatus } from "../../../helpers/GetStaticData";
import {
base64toBlobWithMime,
docIndexStatus,
} from "../../../helpers/GetStaticData";
import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate";
import { useCustomToolStore } from "../../../store/custom-tool-store";
import { useSessionStore } from "../../../store/session-store";
Expand All @@ -22,6 +25,7 @@ import { ManageDocsModal } from "../manage-docs-modal/ManageDocsModal";
import { PdfViewer } from "../pdf-viewer/PdfViewer";
import { TextViewerPre } from "../text-viewer-pre/TextViewerPre";
import usePostHogEvents from "../../../hooks/usePostHogEvents";
import { TextViewer } from "../text-viewer/TextViewer";

let items = [
{
Expand Down Expand Up @@ -103,6 +107,20 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
const { id } = useParams();
const highlightData = selectedHighlight?.highlight || [];

const [blobFileUrl, setBlobFileUrl] = useState("");
const [fileData, setFileData] = useState({});

useEffect(() => {
// Convert blob URL to an object URL
if (fileData.blob) {
const objectUrl = URL.createObjectURL(fileData.blob);
setBlobFileUrl(objectUrl);

// Clean up the URL after component unmount
return () => URL.revokeObjectURL(objectUrl);
}
}, [fileData]);

useEffect(() => {
if (isSimplePromptStudio) {
items = [
Expand Down Expand Up @@ -199,7 +217,8 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
getDocsFunc(details?.tool_id, selectedDoc?.document_id, viewType)
.then((res) => {
const data = res?.data?.data || "";
processGetDocsResponse(data, viewType);
const mimeType = res?.data?.mime_type || "";
processGetDocsResponse(data, viewType, mimeType);
})
.catch((err) => {
handleGetDocsError(err, viewType);
Expand All @@ -226,11 +245,19 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
});
};

const processGetDocsResponse = (data, viewType) => {
const processGetDocsResponse = (data, viewType, mimeType) => {
if (viewType === viewTypes.original) {
const base64String = data || "";
const blob = base64toBlob(base64String);
setFileUrl(URL.createObjectURL(blob));
const blob = base64toBlobWithMime(base64String, mimeType);
setFileData({ blob, mimeType });
const reader = new FileReader();
reader.readAsDataURL(blob);
reader.onload = () => {
setFileUrl(reader.result);
};
reader.onerror = () => {
throw new Error("Fail to load the file");
};
} else if (viewType === viewTypes.extract) {
setExtractTxt(data);
}
Expand Down Expand Up @@ -317,6 +344,19 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
}
};

const renderDoc = (docName, fileUrl, highlightData) => {
const fileType = docName?.split(".").pop().toLowerCase(); // Get the file extension
switch (fileType) {
case "pdf":
return <PdfViewer fileUrl={fileUrl} highlightData={highlightData} />;
case "txt":
case "md":
return <TextViewer fileUrl={fileUrl} />;
default:
return <div>Unsupported file type: {fileType}</div>;
}
};

return (
<div className="doc-manager-layout">
<div className="doc-manager-header">
Expand Down Expand Up @@ -388,7 +428,7 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
setOpenManageDocsModal={setOpenManageDocsModal}
errMsg={fileErrMsg}
>
<PdfViewer fileUrl={fileUrl} highlightData={highlightData} />
{renderDoc(selectedDoc?.document_name, blobFileUrl, highlightData)}
</DocumentViewer>
)}
{activeKey === "2" && (
Expand Down
Loading

0 comments on commit b69ec16

Please sign in to comment.